diff --git a/.gitignore b/.gitignore index 47ff034..7eec5c1 100644 --- a/.gitignore +++ b/.gitignore @@ -105,7 +105,7 @@ celerybeat.pid # Environments .env .venv -env/ +# env/ venv/ ENV/ env.bak/ diff --git a/robot_infra/env/__init__.py b/robot_infra/env/__init__.py new file mode 100644 index 0000000..06558f0 --- /dev/null +++ b/robot_infra/env/__init__.py @@ -0,0 +1,5 @@ +from env.franka_robotiq_env import FrankaRobotiq +from env.franka_pcb_env import PCBEnv +from env.franka_cable_env import RouteCableEnv +from env.franka_cable_env import ResetCableEnv +from env.franka_bin_pick_env import BinPickEnv \ No newline at end of file diff --git a/robot_infra/env/franka_bin_pick_env.py b/robot_infra/env/franka_bin_pick_env.py new file mode 100644 index 0000000..462f5a4 --- /dev/null +++ b/robot_infra/env/franka_bin_pick_env.py @@ -0,0 +1,300 @@ +import gym +from gym import spaces +import numpy as np +from franka_robotiq_env import FrankaRobotiq +import time +from scipy.spatial.transform import Rotation +import requests +import copy +import cv2 +from camera.video_capture import VideoCapture +from camera.rs_capture import RSCapture +import queue + +class BinPickEnv(FrankaRobotiq): + def __init__(self): + super().__init__() + # Bouding box + self.xyz_bounding_box = gym.spaces.Box( + np.array((0.44, -0.12, 0.04)), np.array((0.53, 0.12, 0.1)), dtype=np.float64 + ) + self.rpy_bounding_box = gym.spaces.Box( + # np.array((np.pi-0.001, 0-0.001, np.pi/4)), + # np.array((np.pi+0.001, 0+0.001, 3*np.pi/4)), + np.array((np.pi-0.001, 0-0.001, 0-0.01)), + np.array((np.pi+0.001, 0+0.001, 0+0.01)), + dtype=np.float64, + ) + self.inner_box = gym.spaces.Box( + np.array([0.44, -0.04, 0.04]), + np.array([0.53, 0.04, 0.08]), + dtype=np.float64 + ) + self.drop_box = gym.spaces.Box( + np.array([0.44, -0.04]), + np.array([0.53, 0.04]), + dtype=np.float64 + ) + ## Action/Observation Space + self.action_space = gym.spaces.Box( + np.array((-0.03, -0.03, -0.03, -0.05, -0.05, -0.2, -1)), + np.array((0.03, 0.03, 0.03, 0.05, 0.05, 0.2, 1)), + ) + # enable gripper in observation space + self.observation_space['state_observation']['gripper_pose'] = spaces.Box(-np.inf, np.inf, shape=(1,)) + self.centerpos = copy.deepcopy(self.resetpos) + self.centerpos[:3] = np.mean((self.xyz_bounding_box.high, self.xyz_bounding_box.low), axis=0) #np.array([0.55,-0.05,0.09]) + self.centerpos[2] += 0.01 + self.resetpos = copy.deepcopy(self.centerpos) + self.resetpos[3:] = self.euler_2_quat(np.pi, 0., 0) + + def go_to_rest(self, jpos=False): + count = 0 + requests.post(self.url + "precision_mode") + if jpos: + restp_new = copy.deepcopy(self.currpos) + restp_new[2] = 0.3 + dp = restp_new - self.currpos + count_1 = 0 + self._send_pos_command(self.currpos) + requests.post(self.url + "precision_mode") + while ( + (np.linalg.norm(dp[:3]) > 0.03 or np.linalg.norm(dp[3:]) > 0.04) + ) and count_1 < 50: + if np.linalg.norm(dp[3:]) > 0.05: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.03: + dp[:3] = 0.03 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count_1 += 1 + + print("JOINT RESET") + requests.post(self.url + "jointreset") + else: + # print("RESET") + self.update_currpos() + restp = copy.deepcopy(self.resetpos[:]) + if self.randomreset: + restp[:2] += np.random.uniform(-0.005, 0.005, (2,)) + restp[2] += np.random.uniform(-0.005, 0.005, (1,)) + # restyaw += np.random.uniform(-np.pi / 6, np.pi / 6) + # restp[3:] = self.euler_2_quat(np.pi, 0, restyaw) + + restp_new = copy.deepcopy(restp) + restp_new[2] = 0.13 #cable + dp = restp_new - self.currpos + + height = np.zeros_like(self.resetpos) + height[2] = 0.02 + while count < 10: + self._send_pos_command(self.currpos + height) + time.sleep(0.1) + self.update_currpos() + count += 1 + + count = 0 + while count < 200 and ( + np.linalg.norm(dp[:3]) > 0.01 or np.linalg.norm(dp[3:]) > 0.03 + ): + if np.linalg.norm(dp[3:]) > 0.02: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count += 1 + + dp = restp - self.currpos + count = 0 + while count < 20 and ( + np.linalg.norm(dp[:3]) > 0.01 or np.linalg.norm(dp[3:]) > 0.01 + ): + if np.linalg.norm(dp[3:]) > 0.05: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp - self.currpos + count += 1 + requests.post(self.url + "peg_compliance_mode") + return count < 50 + + def get_im(self): + images = {} + for key, cap in self.cap.items(): + try: + rgb = cap.read() + # images[key] = cv2.resize(rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + if key == 'wrist_1': + # cropped_rgb = rgb[ 100:400, 50:350, :] + cropped_rgb = rgb[:, 80:560, :] + if key == 'wrist_2': + # cropped_rgb = rgb[ 50:350, 200:500, :] #150:450 + cropped_rgb = rgb[:, 80:560, :] + # if key == 'side_1': + # cropped_rgb = rgb[150:330, 230:410, :] + + images[key] = cv2.resize(cropped_rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + # images[key] = cv2.resize(rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + images[key + "_full"] = rgb + # images[f"{key}_depth"] = depth + except queue.Empty: + input(f'{key} camera frozen. Check connect, then press enter to relaunch...') + cap.close() + # if key == 'side_1': + # cap = RSCapture(name='side_1', serial_number='128422270679', depth=True) + # elif key == 'side_2': + # cap = RSCapture(name='side_2', serial_number='127122270146', depth=True) + if key == 'wrist_1': + cap = RSCapture(name='wrist_1', serial_number='130322274175', depth=False) + elif key == 'wrist_2': + # cap = RSCapture(name='wrist_2', serial_number='127122270572', depth=False) + cap = RSCapture(name='wrist_2', serial_number='127122270572', depth=False) + elif key == 'side_1': + cap = RSCapture(name='side_1', serial_number='128422272758', depth=False) + else: + raise KeyError + self.cap[key] = VideoCapture(cap) + return self.get_im() + + self.img_queue.put(images) + return images + + def clip_safety_box(self, pose): + pose[:3] = np.clip( + pose[:3], self.xyz_bounding_box.low, self.xyz_bounding_box.high + ) + + euler = Rotation.from_quat(pose[3:]).as_euler("xyz") + old_sign = np.sign(euler[0]) + euler[0] = ( + np.clip( + euler[0] * old_sign, + self.rpy_bounding_box.low[0], + self.rpy_bounding_box.high[0], + ) + * old_sign + ) + euler[1:] = np.clip( + euler[1:], self.rpy_bounding_box.low[1:], self.rpy_bounding_box.high[1:] + ) + pose[3:] = Rotation.from_euler("xyz", euler).as_quat() + + # Clip xyz to inner box + if self.inner_box.contains(pose[:3]): + print(f'Command: {pose[:3]}') + pose[:3] = self.intersect_line_bbox(self.currpos[:3], pose[:3], self.inner_box.low, self.inner_box.high) + print(f'Clipped: {pose[:3]}') + + return pose + + def intersect_line_bbox(self, p1, p2, bbox_min, bbox_max): + # Define the parameterized line segment + # P(t) = p1 + t(p2 - p1) + tmin = 0 + tmax = 1 + + for i in range(3): + if p1[i] < bbox_min[i] and p2[i] < bbox_min[i]: + return None + if p1[i] > bbox_max[i] and p2[i] > bbox_max[i]: + return None + + # For each axis (x, y, z), compute t values at the intersection points + if abs(p2[i] - p1[i]) > 1e-10: # To prevent division by zero + t1 = (bbox_min[i] - p1[i]) / (p2[i] - p1[i]) + t2 = (bbox_max[i] - p1[i]) / (p2[i] - p1[i]) + + # Ensure t1 is smaller than t2 + if t1 > t2: + t1, t2 = t2, t1 + + tmin = max(tmin, t1) + tmax = min(tmax, t2) + + if tmin > tmax: + return None + + # Compute the intersection point using the t value + intersection = p1 + tmin * (p2 - p1) + + return intersection + + def step(self, action): + start_time = time.time() + action = np.clip(action, self.action_space.low, self.action_space.high) + if self.actionnoise > 0: + a = action[:3] + np.random.uniform( + -self.actionnoise, self.actionnoise, (3,) + ) + else: + a = action[:3] + + self.nextpos = self.currpos.copy() + self.nextpos[:3] = self.nextpos[:3] + a + + ### GET ORIENTATION FROM ACTION + self.nextpos[3:] = ( + Rotation.from_euler("xyz", action[3:6]) + * Rotation.from_quat(self.currpos[3:]) + ).as_quat() + + gripper = action[-1] + if gripper > 0: + if not self.drop_box.contains(self.currpos[:2]): + gripper = (self.currgrip + 1) % 2 + self.set_gripper(gripper) + + self._send_pos_command(self.clip_safety_box(self.nextpos)) + + self.curr_path_length += 1 + dl = time.time() - start_time + + time.sleep(max(0, (1.0 / self.hz) - dl)) + + self.update_currpos() + ob = self._get_obs() + obs_xyz = ob['state_observation']['tcp_pose'][:3] + obs_rpy = ob['state_observation']['tcp_pose'][3:] + reward = 0 + done = self.curr_path_length >= 40 #100 + # if not self.xyz_bounding_box.contains(obs_xyz) or not self.rpy_bounding_box.contains(obs_rpy): + # # print('Truncated: Bouding Box') + # print("xyz: ", self.xyz_bounding_box.contains(obs_xyz), obs_xyz) + # print("rortate: ", self.rpy_bounding_box.contains(obs_rpy), obs_rpy) + # return ob, 0, True, True, {} + return ob, int(reward), done, done, {} + + def reset(self, jpos=False, gripper=None, require_input=False): + self.cycle_count += 1 + if self.cycle_count % 1500 == 0: + self.cycle_count = 0 + jpos=True + + success = self.go_to_rest(jpos=jpos) + self.update_currpos() + self.curr_path_length = 0 + self.recover() + if jpos == True: + self.go_to_rest(jpos=False) + self.update_currpos() + self.recover() + + if require_input: + input("Reset Environment, Press Enter Once Complete: ") + # print("RESET COMPLETE") + requests.post(self.url + "open") + self.currgrip = 0 + time.sleep(1) + + self.update_currpos() + # self.last_quat = self.currpos[3:] + o = self._get_obs() + return o, {} \ No newline at end of file diff --git a/robot_infra/env/franka_cable_env.py b/robot_infra/env/franka_cable_env.py new file mode 100644 index 0000000..ad0d81f --- /dev/null +++ b/robot_infra/env/franka_cable_env.py @@ -0,0 +1,247 @@ +import gym +from gym import spaces +import numpy as np +from franka_robotiq_env import FrankaRobotiq +import time +from scipy.spatial.transform import Rotation +import requests +import copy +import cv2 +from camera.video_capture import VideoCapture +from camera.rs_capture import RSCapture +import queue + +class RouteCableEnv(FrankaRobotiq): + def __init__(self): + super().__init__() + # Bouding box + self.xyz_bounding_box = gym.spaces.Box( + np.array((0.51, -0.1, 0.04)), np.array((0.59, 0, 0.12)), dtype=np.float64 + ) + self.rpy_bounding_box = gym.spaces.Box( + np.array((np.pi-0.001, 0-0.001, np.pi/4)), + np.array((np.pi+0.001, 0+0.001, 3*np.pi/4)), + dtype=np.float64, + ) + ## Action/Observation Space + self.action_space = gym.spaces.Box( + np.array((-0.02, -0.02, -0.02, -0.05, -0.05, -0.1, -1)), + np.array((0.02, 0.02, 0.02, 0.05, 0.05, 0.1, 1)), + ) + # enable gripper in observation space + self.observation_space['state_observation']['gripper_pose'] = spaces.Box(-np.inf, np.inf, shape=(1,)) + # [0.48012088982197254,-0.07218941280725254,0.11078303293108258,0.6995269546628874,0.7134059993136379,0.028532587996196627,0.029996854262000595] + self.resetpos[:3] = np.array([0.55,-0.05,0.09]) + self.resetpos[3:] = self.euler_2_quat(np.pi, 0.03, np.pi/2) + + def go_to_rest(self, jpos=False): + count = 0 + requests.post(self.url + "precision_mode") + if jpos: + restp_new = copy.deepcopy(self.currpos) + restp_new[2] = 0.3 + dp = restp_new - self.currpos + count_1 = 0 + self._send_pos_command(self.currpos) + requests.post(self.url + "precision_mode") + while ( + (np.linalg.norm(dp[:3]) > 0.03 or np.linalg.norm(dp[3:]) > 0.04) + ) and count_1 < 50: + if np.linalg.norm(dp[3:]) > 0.05: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.03: + dp[:3] = 0.03 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count_1 += 1 + + print("JOINT RESET") + requests.post(self.url + "jointreset") + else: + # print("RESET") + self.update_currpos() + restp = copy.deepcopy(self.resetpos[:]) + if self.randomreset: + restp[:2] += np.random.uniform(-0.005, 0.005, (2,)) + restp[2] += np.random.uniform(-0.005, 0.005, (1,)) + # restyaw += np.random.uniform(-np.pi / 6, np.pi / 6) + # restp[3:] = self.euler_2_quat(np.pi, 0, restyaw) + + restp_new = copy.deepcopy(restp) + restp_new[2] = 0.15 #cable + dp = restp_new - self.currpos + + height = np.zeros_like(self.resetpos) + height[2] = 0.02 + while count < 10: + self._send_pos_command(self.currpos + height) + time.sleep(0.1) + self.update_currpos() + count += 1 + + count = 0 + while count < 200 and ( + np.linalg.norm(dp[:3]) > 0.01 or np.linalg.norm(dp[3:]) > 0.03 + ): + if np.linalg.norm(dp[3:]) > 0.02: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count += 1 + + dp = restp - self.currpos + count = 0 + while count < 20 and ( + np.linalg.norm(dp[:3]) > 0.01 or np.linalg.norm(dp[3:]) > 0.01 + ): + if np.linalg.norm(dp[3:]) > 0.05: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp - self.currpos + count += 1 + requests.post(self.url + "peg_compliance_mode") + return count < 50 + + def get_im(self): + images = {} + for key, cap in self.cap.items(): + try: + rgb = cap.read() + # images[key] = cv2.resize(rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + if key == 'wrist_1': + # cropped_rgb = rgb[ 100:400, 50:350, :] + cropped_rgb = rgb[:, 80:560, :] + if key == 'wrist_2': + # cropped_rgb = rgb[ 50:350, 200:500, :] #150:450 + cropped_rgb = rgb[:, 80:560, :] + # if key == 'side_1': + # cropped_rgb = rgb[150:330, 230:410, :] + + images[key] = cv2.resize(cropped_rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + # images[key] = cv2.resize(rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + images[key + "_full"] = rgb + # images[f"{key}_depth"] = depth + except queue.Empty: + input(f'{key} camera frozen. Check connect, then press enter to relaunch...') + cap.close() + # if key == 'side_1': + # cap = RSCapture(name='side_1', serial_number='128422270679', depth=True) + # elif key == 'side_2': + # cap = RSCapture(name='side_2', serial_number='127122270146', depth=True) + if key == 'wrist_1': + cap = RSCapture(name='wrist_1', serial_number='130322274175', depth=False) + elif key == 'wrist_2': + # cap = RSCapture(name='wrist_2', serial_number='127122270572', depth=False) + cap = RSCapture(name='wrist_2', serial_number='127122270572', depth=False) + elif key == 'side_1': + cap = RSCapture(name='side_1', serial_number='128422272758', depth=False) + else: + raise KeyError + self.cap[key] = VideoCapture(cap) + return self.get_im() + + self.img_queue.put(images) + return images + + def step(self, action): + start_time = time.time() + action = np.clip(action, self.action_space.low, self.action_space.high) + if self.actionnoise > 0: + a = action[:3] + np.random.uniform( + -self.actionnoise, self.actionnoise, (3,) + ) + else: + a = action[:3] + + self.nextpos = self.currpos.copy() + self.nextpos[:3] = self.nextpos[:3] + a + + ### GET ORIENTATION FROM ACTION + self.nextpos[3:] = ( + Rotation.from_euler("xyz", action[3:6]) + * Rotation.from_quat(self.currpos[3:]) + ).as_quat() + + self._send_pos_command(self.clip_safety_box(self.nextpos)) + # only change the gripper if the action is above a threshold, either open or close + if len(action) == 7: + if action[-1] > 0.8: + self.set_gripper(1) + elif action[-1] < -0.8: + self.set_gripper(0) + + self.curr_path_length += 1 + dl = time.time() - start_time + + time.sleep(max(0, (1.0 / self.hz) - dl)) + + self.update_currpos() + ob = self._get_obs() + obs_xyz = ob['state_observation']['tcp_pose'][:3] + obs_rpy = ob['state_observation']['tcp_pose'][3:] + reward = 0 + done = self.curr_path_length >= 30 #100 + # if not self.xyz_bounding_box.contains(obs_xyz) or not self.rpy_bounding_box.contains(obs_rpy): + # # print('Truncated: Bouding Box') + # print("xyz: ", self.xyz_bounding_box.contains(obs_xyz), obs_xyz) + # print("rortate: ", self.rpy_bounding_box.contains(obs_rpy), obs_rpy) + # return ob, 0, True, True, {} + return ob, int(reward), done, done, {} + + def reset(self, jpos=False, gripper=None, require_input=False): + self.cycle_count += 1 + if self.cycle_count % 1500 == 0: + self.cycle_count = 0 + jpos=True + # requests.post(self.url + "reset_gripper") + # time.sleep(3) + # self.set_gripper(self.start_gripper, block=False) + self.currgrip = self.start_gripper + + success = self.go_to_rest(jpos=jpos) + self.update_currpos() + self.curr_path_length = 0 + self.recover() + if jpos == True: + self.go_to_rest(jpos=False) + self.update_currpos() + self.recover() + + if require_input: + input("Reset Environment, Press Enter Once Complete: ") + # print("RESET COMPLETE") + self.update_currpos() + # self.last_quat = self.currpos[3:] + o = self._get_obs() + return o, {} + + +class ResetCableEnv(FrankaRobotiq): + def __init__(self): + super().__init__() + # Bouding box + self.xyz_bounding_box = gym.spaces.Box( + np.array((0.62, 0.0, 0.05)), np.array((0.71, 0.08, 0.3)), dtype=np.float64 + ) + self.rpy_bounding_box = gym.spaces.Box( + np.array((np.pi-0.1, -0.1, 1.35)), + np.array((np.pi+0.1, 0.1, 1.7)), + dtype=np.float64, + ) + ## Action/Observation Space + self.action_space = gym.spaces.Box( + np.array((-0.02, -0.02, -0.02, -0.05, -0.05, -0.05, 0 - 1e-8)), + np.array((0.02, 0.02, 0.02, 0.05, 0.05, 0.05, 1 + 1e-8)), + ) + # self.resetpos[:3] = np.array([0.645, 0.17, 0.07]) + # self.resetpos[3:] = self.euler_2_quat(np.pi, 0.03, 0) \ No newline at end of file diff --git a/robot_infra/env/franka_pcb_env.py b/robot_infra/env/franka_pcb_env.py new file mode 100644 index 0000000..df0d716 --- /dev/null +++ b/robot_infra/env/franka_pcb_env.py @@ -0,0 +1,130 @@ +import gym +from gym import spaces +import numpy as np +# from franka.scripts.spacemouse_teleop import SpaceMouseExpert +import time +from franka_robotiq_env import FrankaRobotiq +import copy +import requests + +class PCBEnv(FrankaRobotiq): + def __init__(self): + + super().__init__() + self._TARGET_POSE = [0.6479450830785974,0.17181947852969695,0.056419218166284224, 3.1415, 0.0, 0.0 ] + self._REWARD_THRESHOLD = [0.005, 0.005, 0.0006, 0.03, 0.03, 0.05] + self.observation_space = spaces.Dict( + { + "state_observation": spaces.Dict( + { + "tcp_pose": spaces.Box(-np.inf, np.inf, shape=(6,)), # xyz + euler + "tcp_vel": spaces.Box(-np.inf, np.inf, shape=(6,)), + "tcp_force": spaces.Box(-np.inf, np.inf, shape=(3,)), + "tcp_torque": spaces.Box(-np.inf, np.inf, shape=(3,)), + } + ), + "image_observation": spaces.Dict( + { + "wrist_1": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + "wrist_1_full": spaces.Box(0, 255, shape=(480, 640, 3), dtype=np.uint8), + "wrist_2": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + "wrist_2_full": spaces.Box(0, 255, shape=(480, 640, 3), dtype=np.uint8), + } + ), + } + ) + self.action_space = gym.spaces.Box( + np.array((-0.01, -0.01, -0.01, -0.05, -0.05, -0.05)), + np.array((0.01, 0.01, 0.01, 0.05, 0.05, 0.05)) + ) + self.xyz_bounding_box = gym.spaces.Box( + np.array((0.62, 0.15, 0.03)), + np.array((0.67, 0.19, 0.09)), + dtype=np.float64 + ) + self.rpy_bounding_box = gym.spaces.Box( + np.array((np.pi-0.15, -0.05, -0.1)), + np.array((np.pi+0.1, 0.15, 0.1)), + dtype=np.float64 + ) + self.resetpos[:3] = np.array([0.645, 0.17, 0.07]) + self.resetpos[3:] = self.euler_2_quat(np.pi, 0.03, 0) + self.episodes = 1 + self.randomreset = False + + def _get_state(self): + state = super()._get_state() + state.pop('gripper_pose') + return state + + def go_to_rest(self, jpos=False): + count = 0 + if self.currpos[2] < 0.06: + restp_new = copy.deepcopy(self.currpos) + restp_new[2] += 0.02 + dp = restp_new - self.currpos + while count < 200 and ( + np.linalg.norm(dp[:3]) > 0.01 or np.linalg.norm(dp[3:]) > 0.03 + ): + if np.linalg.norm(dp[3:]) > 0.02: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count += 1 + + + requests.post(self.url + "precision_mode") + if jpos: + restp_new = copy.deepcopy(self.currpos) + restp_new[2] = 0.2 + dp = restp_new - self.currpos + count_1 = 0 + self._send_pos_command(self.currpos) + requests.post(self.url + "precision_mode") + while ( + (np.linalg.norm(dp[:3]) > 0.03 or np.linalg.norm(dp[3:]) > 0.04) + ) and count_1 < 50: + if np.linalg.norm(dp[3:]) > 0.05: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.03: + dp[:3] = 0.03 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count_1 += 1 + + print("JOINT RESET") + requests.post(self.url + "jointreset") + else: + # print("RESET") + restp = copy.deepcopy(self.resetpos[:]) + if self.randomreset: + restp[:2] += np.random.uniform(-0.005, 0.005, (2,)) + restp[2] += np.random.uniform(-0.005, 0.005, (1,)) + # restyaw += np.random.uniform(-np.pi / 6, np.pi / 6) + # restp[3:] = self.euler_2_quat(np.pi, 0, restyaw) + + restp_new = copy.deepcopy(restp) + restp_new[2] = 0.07 #PCB + self.update_currpos() + dp = restp_new - self.currpos + while count < 200 and ( + np.linalg.norm(dp[:3]) > 0.005 or np.linalg.norm(dp[3:]) > 0.03 + ): + if np.linalg.norm(dp[3:]) > 0.02: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count += 1 + + requests.post(self.url + "pcb_compliance_mode") + return count < 200 \ No newline at end of file diff --git a/robot_infra/env/franka_robotiq_env.py b/robot_infra/env/franka_robotiq_env.py new file mode 100644 index 0000000..0a7540e --- /dev/null +++ b/robot_infra/env/franka_robotiq_env.py @@ -0,0 +1,472 @@ +"""Gym Interface for Franka""" +import numpy as np +import gym +from pyquaternion import Quaternion +import cv2 +import copy +from scipy.spatial.transform import Rotation +import time +import requests +from gym import core, spaces +from camera.video_capture import VideoCapture +from camera.rs_capture import RSCapture +import queue +from PIL import Image +from queue import Queue +import threading +import os + +class ImageDisplayer(threading.Thread): + def __init__(self, queue): + threading.Thread.__init__(self) + self.queue = queue + self.stop_signal = False + self.daemon = True # make this a daemon thread + + self.video = [] + + video_dir = '/home/undergrad/franka_fwbw_pick_screw_vice_ckpts' + os.makedirs(video_dir, exist_ok=True) + uuid = time.strftime("%Y%m%d-%H%M%S") + self.wrist1 = cv2.VideoWriter(os.path.join(video_dir, f'wrist_1_{uuid}.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), 24, (640, 480)) + self.wrist2 = cv2.VideoWriter(os.path.join(video_dir, f'wrist_2_{uuid}.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), 24, (640, 480)) + self.frame_counter = 0 + + def run(self): + while True: + img_array = self.queue.get() # retrieve an image from the queue + if img_array is None: # None is our signal to exit + break + # pair1 = np.concatenate([img_array['wrist_1_full'], img_array['wrist_2_full']], axis=0) + pair1 = np.concatenate([img_array['wrist_1'], img_array['wrist_2']], axis=0) + # pair1 = np.concatenate([img_array['wrist_1'], img_array['wrist_2'], img_array['side_1']], axis=0) + # pair2 = np.concatenate([img_array['side_2_full'], img_array['side_1_full']], axis=0) + # concatenated = np.concatenate([pair1, pair2], axis=1) + cv2.imshow('wrist', pair1/255.) + cv2.waitKey(1) + + self.wrist1.write(img_array['wrist_1_full']) + self.wrist2.write(img_array['wrist_2_full']) + self.frame_counter += 1 + if self.frame_counter == 400: + self.wrist1.release() + self.wrist2.release() + + +class FrankaRobotiq(gym.Env): + def __init__( + self, + randomReset=False, + hz=10, + start_gripper=0, + ): + + self._TARGET_POSE = [0.6636488814118523,0.05388642290645651,0.09439445897864279, 3.1339503, 0.009167, 1.5550434] + self._REWARD_THRESHOLD = [0.01, 0.01, 0.01, 0.2, 0.2, 0.2] + self.resetpos = np.zeros(7) + + self.resetpos[:3] = self._TARGET_POSE[:3] + self.resetpos[2] += 0.07 + self.resetpos[3:] = self.euler_2_quat(self._TARGET_POSE[3], self._TARGET_POSE[4], self._TARGET_POSE[5]) + + self.currpos = self.resetpos.copy() + self.currvel = np.zeros((6,)) + self.q = np.zeros((7,)) + self.dq = np.zeros((7,)) + self.currforce = np.zeros((3,)) + self.currtorque = np.zeros((3,)) + self.currjacobian = np.zeros((6, 7)) + self.start_gripper = start_gripper + self.currgrip = self.start_gripper #start_gripper + self.lastsent = time.time() + self.randomreset = randomReset + self.actionnoise = 0 + self.hz = hz + + ## NUC + self.ip = "127.0.0.1" + self.url = "http://" + self.ip + ":5000/" + + # Bouding box + self.xyz_bounding_box = gym.spaces.Box( + np.array((0.62, 0.0, 0.05)), np.array((0.71, 0.08, 0.3)), dtype=np.float64 + ) + self.rpy_bounding_box = gym.spaces.Box( + np.array((np.pi-0.1, -0.1, 1.35)), + np.array((np.pi+0.1, 0.1, 1.7)), + dtype=np.float64, + ) + ## Action/Observation Space + self.action_space = gym.spaces.Box( + np.array((-0.02, -0.02, -0.02, -0.05, -0.05, -0.05, 0 - 1e-8)), + np.array((0.02, 0.02, 0.02, 0.05, 0.05, 0.05, 1 + 1e-8)), + ) + + self.observation_space = spaces.Dict( + { + "state_observation": spaces.Dict( + { + # "tcp_pose": spaces.Box(-np.inf, np.inf, shape=(7,)), # xyz + quat + "tcp_pose": spaces.Box(-np.inf, np.inf, shape=(6,)), # xyz + euler + "tcp_vel": spaces.Box(-np.inf, np.inf, shape=(6,)), + "gripper_pose": spaces.Box(-1, 1, shape=(1,)), + # "q": spaces.Box(-np.inf, np.inf, shape=(7,)), + # "dq": spaces.Box(-np.inf, np.inf, shape=(7,)), + "tcp_force": spaces.Box(-np.inf, np.inf, shape=(3,)), + "tcp_torque": spaces.Box(-np.inf, np.inf, shape=(3,)), + # "jacobian": spaces.Box(-np.inf, np.inf, shape=((6, 7))), + } + ), + "image_observation": spaces.Dict( + { + "wrist_1": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + "wrist_1_full": spaces.Box(0, 255, shape=(480, 640, 3), dtype=np.uint8), + "wrist_2": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + "wrist_2_full": spaces.Box(0, 255, shape=(480, 640, 3), dtype=np.uint8), + # "side_1": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + # "side_1_full": spaces.Box(0, 255, shape=(480, 640, 3), dtype=np.uint8), + } + ), + } + ) + self.cycle_count = 0 + self.cap_wrist_1 = VideoCapture(RSCapture(name='wrist_1', serial_number='130322274175', depth=False)) + self.cap_wrist_2 = VideoCapture(RSCapture(name='wrist_2', serial_number='127122270572', depth=False)) + # self.cap_side_1 = VideoCapture(RSCapture(name='side_1', serial_number='128422272758', depth=False)) + + # self.cap_side_1 = VideoCapture( + # RSCapture(name="side_1", serial_number="128422270679", depth=True) + # ) + # self.cap_side_2 = VideoCapture( + # RSCapture(name="side_2", serial_number="127122270146", depth=True) + # ) + self.cap = { + # "side_1": self.cap_side_1, + # "side_2": self.cap_side_2, + "wrist_1": self.cap_wrist_1, + "wrist_2": self.cap_wrist_2, + } + + self.img_queue = queue.Queue() + self.displayer = ImageDisplayer(self.img_queue) + self.displayer.start() + print("Initialized Franka") + + def recover(self): + requests.post(self.url + "clearerr") + + def _send_pos_command(self, pos): + self.recover() + arr = np.array(pos).astype(np.float32) + data = {"arr": arr.tolist()} + requests.post(self.url + "pose", json=data) + + def update_currpos(self): + ps = requests.post(self.url + "getstate").json() + self.currpos[:] = np.array(ps["pose"]) + self.currvel[:] = np.array(ps["vel"]) + + self.currforce[:] = np.array(ps["force"]) + self.currtorque[:] = np.array(ps["torque"]) + self.currjacobian[:] = np.reshape(np.array(ps["jacobian"]), (6, 7)) + + self.q[:] = np.array(ps["q"]) + self.dq[:] = np.array(ps["dq"]) + + def set_gripper(self, position, block=True): + if position == 1: + st = 'close' + elif position == 0: + st = 'open' + else: + raise ValueError(f'Gripper position {position} not supported') + + ### IMPORTANT, IF FRANKA GRIPPER GETS OPEN/CLOSE COMMANDS TOO QUICKLY IT WILL FREEZE + ### THIS MAKES SURE CONSECUTIVE GRIPPER CHANGES ONLY HAPPEN 1 SEC APART + now = time.time() + delta = now - self.lastsent + if delta >= 1: + requests.post(self.url + st) + self.lastsent = time.time() + self.currgrip = position + # time.sleep(max(0, 1.5 - delta)) + + + def clip_safety_box(self, pose): + pose[:3] = np.clip( + pose[:3], self.xyz_bounding_box.low, self.xyz_bounding_box.high + ) + euler = Rotation.from_quat(pose[3:]).as_euler("xyz") + old_sign = np.sign(euler[0]) + euler[0] = ( + np.clip( + euler[0] * old_sign, + self.rpy_bounding_box.low[0], + self.rpy_bounding_box.high[0], + ) + * old_sign + ) + euler[1:] = np.clip( + euler[1:], self.rpy_bounding_box.low[1:], self.rpy_bounding_box.high[1:] + ) + + pose[3:] = Rotation.from_euler("xyz", euler).as_quat() + + return pose + + def move_to_pos(self, pos): + start_time = time.time() + self._send_pos_command(self.clip_safety_box(pos)) + dl = time.time() - start_time + time.sleep(max(0, (1.0 / self.hz) - dl)) + self.update_currpos() + obs = self._get_obs() + return obs + + def step(self, action): + start_time = time.time() + action = np.clip(action, self.action_space.low, self.action_space.high) + if self.actionnoise > 0: + a = action[:3] + np.random.uniform( + -self.actionnoise, self.actionnoise, (3,) + ) + else: + a = action[:3] + + self.nextpos = self.currpos.copy() + self.nextpos[:3] = self.nextpos[:3] + a + + ### GET ORIENTATION FROM ACTION + self.nextpos[3:] = ( + Rotation.from_euler("xyz", action[3:6]) + * Rotation.from_quat(self.currpos[3:]) + ).as_quat() + + # self.nextpos = self.clip_safety_box(self.nextpos) + # self._send_pos_command(self.nextpos) + self._send_pos_command(self.clip_safety_box(self.nextpos)) + # self.set_gripper(action[-1]) + + self.curr_path_length += 1 + dl = time.time() - start_time + + time.sleep(max(0, (1.0 / self.hz) - dl)) + + self.update_currpos() + ob = self._get_obs() + obs_xyz = ob['state_observation']['tcp_pose'][:3] + # obs_rpy = self.quat_2_euler(ob['state_observation']['tcp_pose'][3:7]) + obs_rpy = ob['state_observation']['tcp_pose'][3:] + reward = self.binary_reward_tcp(ob['state_observation']['tcp_pose']) + done = self.curr_path_length >= 100 + # if not self.xyz_bounding_box.contains(obs_xyz) or not self.rpy_bounding_box.contains(obs_rpy): + # # print('Truncated: Bouding Box') + # # print("xyz: ", self.xyz_bounding_box.contains(obs_xyz), obs_xyz) + # # print("rortate: ", self.rpy_bounding_box.contains(obs_rpy), obs_rpy) + # return ob, 0, True, True, {} + # return ob, int(reward), done or reward, done, {} + return ob, int(reward), done, done, {} + + + def binary_reward_tcp(self, current_pose,): + # euler_angles = np.abs(R.from_quat(current_pose[3:]).as_euler("xyz")) + euler_angles = np.abs(current_pose[3:]) + current_pose = np.hstack([current_pose[:3],euler_angles]) + delta = np.abs(current_pose - self._TARGET_POSE) + if np.all(delta < self._REWARD_THRESHOLD): + return True + else: + # print(f'Goal not reached, the difference is {delta}, the desired threshold is {_REWARD_THRESHOLD}') + return False + + def get_im(self): + images = {} + for key, cap in self.cap.items(): + try: + rgb = cap.read() + # images[key] = cv2.resize(rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + # if key == 'wrist_1': + # cropped_rgb = rgb[ 0:300, 150:450, :] + # if key == 'wrist_2': + # cropped_rgb = rgb[ 50:350, 150:450, :] + if key == 'wrist_1': + cropped_rgb = rgb[:, 80:560, :] + if key == 'wrist_2': + cropped_rgb = rgb[:, 80:560, :] + images[key] = cv2.resize(cropped_rgb, self.observation_space['image_observation'][key].shape[:2][::-1]) + images[key + "_full"] = rgb + # images[f"{key}_depth"] = depth + except queue.Empty: + input(f'{key} camera frozen. Check connect, then press enter to relaunch...') + cap.close() + # if key == 'side_1': + # cap = RSCapture(name='side_1', serial_number='128422270679', depth=True) + # elif key == 'side_2': + # cap = RSCapture(name='side_2', serial_number='127122270146', depth=True) + if key == 'wrist_1': + cap = RSCapture(name='wrist_1', serial_number='130322274175', depth=False) + elif key == 'wrist_2': + cap = RSCapture(name='wrist_2', serial_number='127122270572', depth=False) + else: + raise KeyError + self.cap[key] = VideoCapture(cap) + return self.get_im() + + self.img_queue.put(images) + return images + + def _get_state(self): + state_observation = { + "tcp_pose": np.concatenate((self.currpos[:3], self.quat_2_euler(self.currpos[3:]))), + "tcp_vel": self.currvel, + "gripper_pose": self.currgrip, + # "q": self.q, + # "dq": self.dq, + "tcp_force": self.currforce, + "tcp_torque": self.currtorque, + # "jacobian": self.currjacobian, + } + return state_observation + + def _get_obs(self): + images = self.get_im() + state_observation = self._get_state() + + return copy.deepcopy(dict( + image_observation=images, + state_observation=state_observation + )) + + def go_to_rest(self, jpos=False): + count = 0 + requests.post(self.url + "precision_mode") + if jpos: + restp_new = copy.deepcopy(self.currpos) + restp_new[2] = 0.3 + dp = restp_new - self.currpos + count_1 = 0 + self._send_pos_command(self.currpos) + requests.post(self.url + "precision_mode") + while ( + (np.linalg.norm(dp[:3]) > 0.03 or np.linalg.norm(dp[3:]) > 0.04) + ) and count_1 < 50: + if np.linalg.norm(dp[3:]) > 0.05: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.03: + dp[:3] = 0.03 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count_1 += 1 + + print("JOINT RESET") + requests.post(self.url + "jointreset") + else: + # print("RESET") + self.update_currpos() + restp = copy.deepcopy(self.resetpos[:]) + if self.randomreset: + restp[:2] += np.random.uniform(-0.005, 0.005, (2,)) + restp[2] += np.random.uniform(-0.005, 0.005, (1,)) + # restyaw += np.random.uniform(-np.pi / 6, np.pi / 6) + # restp[3:] = self.euler_2_quat(np.pi, 0, restyaw) + + restp_new = copy.deepcopy(restp) + restp_new[2] = 0.2 #PEG + dp = restp_new - self.currpos + while count < 200 and ( + np.linalg.norm(dp[:3]) > 0.01 or np.linalg.norm(dp[3:]) > 0.03 + ): + if np.linalg.norm(dp[3:]) > 0.02: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp_new - self.currpos + count += 1 + + dp = restp - self.currpos + count = 0 + while count < 20 and ( + np.linalg.norm(dp[:3]) > 0.01 or np.linalg.norm(dp[3:]) > 0.01 + ): + if np.linalg.norm(dp[3:]) > 0.05: + dp[3:] = 0.05 * dp[3:] / np.linalg.norm(dp[3:]) + if np.linalg.norm(dp[:3]) > 0.02: + dp[:3] = 0.02 * dp[:3] / np.linalg.norm(dp[:3]) + self._send_pos_command(self.currpos + dp) + time.sleep(0.1) + self.update_currpos() + dp = restp - self.currpos + count += 1 + requests.post(self.url + "peg_compliance_mode") + return count < 50 + + def reset(self, jpos=False, gripper=None, require_input=False): + self.cycle_count += 1 + if self.cycle_count % 150 == 0: + self.cycle_count = 0 + jpos=True + # requests.post(self.url + "reset_gripper") + # time.sleep(3) + # self.set_gripper(self.start_gripper, block=False) + self.currgrip = self.start_gripper + + success = self.go_to_rest(jpos=jpos) + self.update_currpos() + self.curr_path_length = 0 + self.recover() + if jpos == True: + self.go_to_rest(jpos=False) + self.update_currpos() + self.recover() + + if require_input: + input("Reset Environment, Press Enter Once Complete: ") + # print("RESET COMPLETE") + self.update_currpos() + # self.last_quat = self.currpos[3:] + o = self._get_obs() + return o, {} + + def quat_2_euler(self, quat): + # calculates and returns: yaw, pitch, roll from given quaternion + if not isinstance(quat, Quaternion): + quat = Quaternion(quat) + yaw, pitch, roll = quat.yaw_pitch_roll + return yaw + np.pi, pitch, roll + + def euler_2_quat(self, yaw=np.pi / 2, pitch=0.0, roll=np.pi): + yaw = np.pi - yaw + yaw_matrix = np.array( + [ + [np.cos(yaw), -np.sin(yaw), 0.0], + [np.sin(yaw), np.cos(yaw), 0.0], + [0, 0, 1.0], + ] + ) + pitch_matrix = np.array( + [ + [np.cos(pitch), 0.0, np.sin(pitch)], + [0.0, 1.0, 0.0], + [-np.sin(pitch), 0, np.cos(pitch)], + ] + ) + roll_matrix = np.array( + [ + [1.0, 0, 0], + [0, np.cos(roll), -np.sin(roll)], + [0, np.sin(roll), np.cos(roll)], + ] + ) + rot_mat = yaw_matrix.dot(pitch_matrix.dot(roll_matrix)) + return Quaternion(matrix=rot_mat).elements + + def close_camera(self): + # self.cap_top.close() + # self.cap_side.close() + self.cap_wrist_2.close() + self.cap_wrist_1.close() \ No newline at end of file diff --git a/robot_infra/env/wrappers.py b/robot_infra/env/wrappers.py new file mode 100644 index 0000000..af95a0b --- /dev/null +++ b/robot_infra/env/wrappers.py @@ -0,0 +1,208 @@ +import time +from gym import Env, spaces +import gym +import numpy as np +from gym.spaces import Box +import copy +from robot_infra.spacemouse.spacemouse_teleop import SpaceMouseExpert + + +class ProxyEnv(Env): + def __init__(self, wrapped_env): + self._wrapped_env = wrapped_env + self.action_space = self._wrapped_env.action_space + self.observation_space = self._wrapped_env.observation_space + + @property + def wrapped_env(self): + return self._wrapped_env + + def reset(self, **kwargs): + return self._wrapped_env.reset(**kwargs) + + def step(self, action): + return self._wrapped_env.step(action) + + def render(self, *args, **kwargs): + return self._wrapped_env.render(*args, **kwargs) + + @property + def horizon(self): + return self._wrapped_env.horizon + + def terminate(self): + if hasattr(self.wrapped_env, "terminate"): + self.wrapped_env.terminate() + + def seed(self, _seed): + return self.wrapped_env.seed(_seed) + + def __getattr__(self, attr): + if attr == '_wrapped_env': + raise AttributeError() + if attr == 'planner': + return self._planner + if attr == 'set_vf': + return self.set_vf + return getattr(self._wrapped_env, attr) + # try: + # getattr(self, attr) + # except Exception: + # return getattr(self._wrapped_env, attr) + + def __getstate__(self): + """ + This is useful to override in case the wrapped env has some funky + __getstate__ that doesn't play well with overriding __getattr__. + + The main problematic case is/was gym's EzPickle serialization scheme. + :return: + """ + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + + def __str__(self): + return '{}({})'.format(type(self).__name__, self.wrapped_env) + +class GripperCloseEnv(ProxyEnv): + def __init__( + self, + env, + ): + ProxyEnv.__init__(self, env) + ub = self._wrapped_env.action_space + assert ub.shape == (7,) + self.action_space = Box(ub.low[:6], ub.high[:6]) + self.observation_space = spaces.Dict( + { + "state_observation": spaces.Dict( + { + "tcp_pose": spaces.Box(-np.inf, np.inf, shape=(6,)), # xyz + euler + "tcp_vel": spaces.Box(-np.inf, np.inf, shape=(6,)), + "tcp_force": spaces.Box(-np.inf, np.inf, shape=(3,)), + "tcp_torque": spaces.Box(-np.inf, np.inf, shape=(3,)), + } + ), + "image_observation": spaces.Dict( + { + "wrist_1": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + "wrist_1_full": spaces.Box(0, 255, shape=(480, 640, 3), dtype=np.uint8), + "wrist_2": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + "wrist_2_full": spaces.Box(0, 255, shape=(480, 640, 3), dtype=np.uint8), + } + ), + } + ) + + def step(self, action): + a = np.zeros(self._wrapped_env.action_space.shape) + a[:6] = copy.deepcopy(action) + a[6] = 1 + return self._wrapped_env.step(a) + +class SpacemouseIntervention(ProxyEnv): + def __init__(self, env, gripper_enabled=False): + ProxyEnv.__init__(self, env) + self._wrapped_env = env + self.action_space = self._wrapped_env.action_space + self.gripper_enabled = gripper_enabled + if self.gripper_enabled: + assert self.action_space.shape == (7,) # maybe not so elegant + self.observation_space = self._wrapped_env.observation_space + self.expert = SpaceMouseExpert( + xyz_dims=3, + xyz_remap=[0, 1, 2], + xyz_scale=200, + rot_scale=200, + all_angles=True + ) + self.last_intervene = 0 + + def expert_action(self, action): + ''' + Input: + - action: policy action + Output: + - action: spacemouse action if nonezero; else, policy action + ''' + controller_a, _, left, right = self.expert.get_action() + expert_a = np.zeros((6,)) + if self.gripper_enabled: + expert_a = np.zeros((7,)) + expert_a[-1] = np.random.uniform(-1, 0) + + expert_a[:3] = controller_a[:3] # XYZ + expert_a[3] = controller_a[4] # Roll + expert_a[4] = controller_a[5] # Pitch + expert_a[5] = -controller_a[6] # Yaw + + if self.gripper_enabled: + if left: + expert_a[6] = np.random.uniform(0, 1) + self.last_intervene = time.time() + + if np.linalg.norm(expert_a[:6]) > 0.001: + self.last_intervene = time.time() + else: + if np.linalg.norm(expert_a) > 0.001: + self.last_intervene = time.time() + + if time.time() - self.last_intervene < 0.5: + return expert_a, left, right + return action, left, right + + def step(self, action): + expert_action, left, right = self.expert_action(action) + o, r, done, truncated, env_info = self._wrapped_env.step(expert_action) + env_info['expert_action'] = expert_action + env_info['right'] = right + return o, r, done, truncated, env_info + +class FourDoFWrapper(gym.ActionWrapper): + def __init__(self, env: Env): + super().__init__(env) + + def action(self, action): + a = np.zeros(4) + a[:3] = action[:3] + a[-1] = action[-1] + return a + +class TwoCameraFrankaWrapper(gym.ObservationWrapper): + def __init__(self, env): + ProxyEnv.__init__(self, env) + self.env = env + self.observation_space = spaces.Dict( + { + "state": spaces.flatten_space(self.env.observation_space['state_observation']), + "wrist_1": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + "wrist_2": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + # "side_1": spaces.Box(0, 255, shape=(128, 128, 3), dtype=np.uint8), + } + ) + + def observation(self, obs): + ob = { + 'state': spaces.flatten(self.env.observation_space['state_observation'], + obs['state_observation']), + 'wrist_1': obs['image_observation']['wrist_1'][...,::-1], # flip color channel + 'wrist_2': obs['image_observation']['wrist_2'][...,::-1], # flip color channel + # 'side_1': obs['image_observation']['side_1'][...,::-1], # flip color channel + } + return ob + +class ResetFreeWrapper(gym.Wrapper): + def __init__(self, env): + super().__init__(env) + self.task_id = 0 # 0: place into silver bin, 1: place into brown bin + + def reset(self, task_id=0): + self.task_id = task_id + print(f'reset to task {self.task_id}') + if self.task_id == 0: + self.resetpos[1] = self.centerpos[1] + 0.1 + else: + self.resetpos[1] = self.centerpos[1] - 0.1 + return self.env.reset() \ No newline at end of file