Skip to content

Commit

Permalink
Merge pull request #281 from leggedrobotics/dev/extract_traj
Browse files Browse the repository at this point in the history
Dev/extract traj
  • Loading branch information
Robin271828 authored Sep 12, 2023
2 parents e7e5e69 + c0b143e commit 5798f0b
Show file tree
Hide file tree
Showing 11 changed files with 543 additions and 134 deletions.
190 changes: 190 additions & 0 deletions scripts/dataset_generation/extract_binary_maps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import sys
import os
from pathlib import Path
import time
from tqdm import tqdm
import subprocess
import yaml

from tf_bag import BagTfTransformer
import rospy
import rosparam
from sensor_msgs.msg import Image, CameraInfo, CompressedImage
import rosbag

from postprocessing_tools_ros.merging import merge_bags_single, merge_bags_all

# from py_image_proc_cuda import ImageProcCuda
# from cv_bridge import CvBridge

from wild_visual_navigation import WVN_ROOT_DIR
from wild_visual_navigation.utils import perugia_dataset, ROOT_DIR

sys.path.append(f"{WVN_ROOT_DIR}/wild_visual_navigation_ros/scripts")
from wild_visual_navigation_node import WvnRosInterface

sys.path.append(f"{WVN_ROOT_DIR}/wild_visual_navigation_anymal/scripts")
from anymal_msg_converter_node import anymal_msg_callback

# We need to do the following
# 1. Debayering cam4 -> send via ros and wait for result ("correct params")
# 2. anymal_state_topic -> /wild_visual_navigation_node/robot_state
# 3. Feed into wild_visual_navigation_node ("correct params")
# # Iterate rosbags


def get_bag_info(rosbag_path: str) -> dict:
# This queries rosbag info using subprocess and get the YAML output to parse the topics
info_dict = yaml.safe_load(
subprocess.Popen(["rosbag", "info", "--yaml", rosbag_path], stdout=subprocess.PIPE).communicate()[0]
)
return info_dict


class BagTfTransformerWrapper:
def __init__(self, bag):
self.tf_listener = BagTfTransformer(bag)

def waitForTransform(self, parent_frame, child_frame, time, duration):
return self.tf_listener.waitForTransform(parent_frame, child_frame, time)

def lookupTransform(self, parent_frame, child_frame, time):
try:
return self.tf_listener.lookupTransform(parent_frame, child_frame, time)
except:
return (None, None)


def do(n, dry_run):
d = perugia_dataset[n]

if bool(dry_run):
print(d)
return

s = os.path.join(ROOT_DIR, d["name"])

valid_topics = ["/state_estimator/anymal_state", "/wide_angle_camera_front/img_out"]

rosbags = ["/home/rschmid/RosBags/6_proc/images.bag",
"/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_0.bag",
"/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_1.bag"]

output_bag_wvn = s + "_wvn.bag"
output_bag_tf = s + "_tf.bag"

if not os.path.exists(output_bag_tf):
total_included_count, total_skipped_count = merge_bags_single(
input_bag=rosbags, output_bag=output_bag_tf, topics="/tf /tf_static", verbose=True
)
if not os.path.exists(output_bag_wvn):
total_included_count, total_skipped_count = merge_bags_single(
input_bag=rosbags, output_bag=output_bag_wvn, topics=" ".join(valid_topics), verbose=True
)

# Setup WVN node
rospy.init_node("wild_visual_navigation_node")

mission = s.split("/")[-1]

running_store_folder = f"/home/rschmid/RosBags/output/{mission}"

if os.path.exists(running_store_folder):
print("Folder already exists, but proceeding!")
# return

rosparam.set_param("wild_visual_navigation_node/mode", "extract_labels")
rosparam.set_param("wild_visual_navigation_node/running_store_folder", running_store_folder)

# for proprioceptive callback
state_msg_valid = False
desired_twist_msg_valid = False

wvn_ros_interface = WvnRosInterface()
print("-" * 80)

print("start loading tf")
tf_listener = BagTfTransformerWrapper(output_bag_tf)
wvn_ros_interface.setup_rosbag_replay(tf_listener)
print("done loading tf")

# Höngg new
info_msg = CameraInfo()
info_msg.height = 1080
info_msg.width = 1440
info_msg.distortion_model = "equidistant"
info_msg.D = [0.4316922809468283, 0.09279900476637248, -0.4010909691803734, 0.4756163338479413]
info_msg.K = [575.6050407221768, 0.0, 745.7312198525915, 0.0, 578.564849365178, 519.5207040671075, 0.0, 0.0, 1.0]
info_msg.P = [575.6050407221768, 0.0, 745.7312198525915, 0.0, 0.0, 578.564849365178, 519.5207040671075, 0.0, 0.0, 0.0, 1.0, 0.0]
info_msg.R = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]

rosbag_info_dict = get_bag_info(output_bag_wvn)
total_msgs = sum([x["messages"] for x in rosbag_info_dict["topics"] if x["topic"] in valid_topics])
total_time_img = 0
total_time_state = 0
n = 0

with rosbag.Bag(output_bag_wvn, "r") as bag:
if rospy.is_shutdown():
return

start_time = rospy.Time.from_sec(bag.get_start_time() + d["start"])
end_time = rospy.Time.from_sec(bag.get_start_time() + d["stop"])
with tqdm(
total=total_msgs,
desc="Total",
colour="green",
position=1,
bar_format="{desc:<13}{percentage:3.0f}%|{bar:20}{r_bar}",
) as pbar:
for (topic, msg, ts) in bag.read_messages(topics=None, start_time=start_time, end_time=end_time):

if rospy.is_shutdown():
return

pbar.update(1)
st = time.time()
if topic == "/state_estimator/anymal_state":
state_msg = anymal_msg_callback(msg, return_msg=True)
state_msg_valid = True

elif topic == "/wide_angle_camera_front/img_out":
image_msg = msg
# print("Received /wide_angle_camera_front/img_out")

info_msg.header = msg.header
camera_options = {}
camera_options['name'] = "wide_angle_camera_front"
camera_options["use_for_training"] = True

info_msg.header = msg.header
try:
wvn_ros_interface.image_callback(image_msg, info_msg, camera_options)
except Exception as e:
print("Bad image_callback", e)

total_time_img += time.time() - st
# print(f"image time: {total_time_img} , state time: {total_time_state}")
# print("add image")
if state_msg_valid:
try:
wvn_ros_interface.robot_state_callback(state_msg, None)
except Exception as e:
print("Bad robot_state callback ", e)

state_msg_valid = False
total_time_state += time.time() - st

print("Finished with converting the dataset")
rospy.signal_shutdown("stop the node")


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n", type=int, default=0, help="Store data")
parser.add_argument("--dry_run", type=int, default=0, help="Store data")
args = parser.parse_args()

do(args.n, args.dry_run)
46 changes: 46 additions & 0 deletions wild_visual_navigation/image_projector/image_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,52 @@ def project_and_render(

return self.masks, image_overlay, projected_points, valid_points

def project_and_render_on_map(
self, pose_base_in_world: torch.tensor, points: torch.tensor, colors: torch.tensor, map_resolution: float, map_size: int, image: torch.tensor = None
):
"""Projects the points and returns an image with the projection
Args:
points: (torch.Tensor, dtype=torch.float32, shape=(B, N, 3)): B batches, of N input points in 3D space
colors: (torch.Tensor, rtype=torch.float32, shape=(B, 3))
Returns:
out_img (torch.tensor, dtype=torch.int64): Image with projected points
"""

# self.masks = self.masks * 0.0
B = self.camera.batch_size
C = 3 # RGB channel output
H = self.camera.height.item()
W = self.camera.width.item()
self.masks = torch.zeros((B, C, H, W), dtype=torch.float32, device=self.camera.camera_matrix.device)
image_overlay = image

T_BW = pose_base_in_world.inverse()
# Convert from fixed to base frame
points_B = transform_points(T_BW, points)

# Remove z dimension
# TODO: project footprint on gravity aligned plane
flat_points = points_B[:, :, :-1]

# Shift to grid map coordinates
flat_points = flat_points / map_resolution + map_size / 2

# Fill the mask
self.masks = draw_convex_polygon(self.masks, flat_points, colors)

# Draw on image (if applies)
if image is not None:
if len(image.shape) != 4:
image = image[None]
image_overlay = draw_convex_polygon(image, flat_points, colors)

# Return torch masks
self.masks[self.masks == 0.0] = torch.nan

return self.masks, image_overlay

def resize_image(self, image: torch.tensor):
return self.image_crop(image)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(
feature_type: str = "dino",
min_samples_for_training: int = 10,
vis_node_index: int = 10,
map_resolution: float = 0.1,
map_size: int = 128,
mode: bool = False,
extraction_store_folder=None,
anomaly_detection: bool = False,
Expand All @@ -62,6 +64,8 @@ def __init__(
self._scale_traversability = scale_traversability
self._params = params
self._scale_traversability_threshold = 0
self._map_resolution = map_resolution
self._map_size = map_size
self._anomaly_detection = anomaly_detection

if self._scale_traversability:
Expand Down Expand Up @@ -306,7 +310,7 @@ def add_mission_node(self, node: MissionNode, verbose: bool = False, update_feat

@accumulate_time
@torch.no_grad()
def add_proprio_node(self, pnode: ProprioceptionNode):
def add_proprio_node(self, pnode: ProprioceptionNode, projection_mode: str = "image"):
"""Adds a node to the proprioceptive graph to store proprioception
Args:
Expand Down Expand Up @@ -356,38 +360,48 @@ def add_proprio_node(self, pnode: ProprioceptionNode):
color = torch.ones((3,), device=self._device)

# New implementation
B = len(mission_nodes)
B = len(mission_nodes) # Number of mission nodes to project

# Prepare batches
K = torch.eye(4, device=self._device).repeat(B, 1, 1)
supervision_masks = torch.zeros(last_mission_node.supervision_mask.shape, device=self._device).repeat(
B, 1, 1, 1
)
pose_camera_in_world = torch.eye(4, device=self._device).repeat(B, 1, 1)
pose_base_in_world = torch.eye(4, device=self._device).repeat(B, 1, 1)

H = last_mission_node.image_projector.camera.height
W = last_mission_node.image_projector.camera.width
footprints = footprint.repeat(B, 1, 1)

for i, mnode in enumerate(mission_nodes):
K[i] = mnode.image_projector.camera.intrinsics
pose_camera_in_world[i] = mnode.pose_cam_in_world
pose_base_in_world[i] = mnode.pose_base_in_world

if (not hasattr(mnode, "supervision_mask")) or (mnode.supervision_mask is None):
continue
else:
supervision_masks[i] = mnode.supervision_mask
supervision_masks[i] = mnode.supervision_mask # Getting all the existing supervision masks

im = ImageProjector(K, H, W)
mask, _, _, _ = im.project_and_render(pose_camera_in_world, footprints, color)

map_resolution = self._map_resolution
map_size = self._map_size

if projection_mode == "image":
mask, _, _, _ = im.project_and_render(pose_camera_in_world, footprints, color) # Generating the new supervisiom mask to add
elif projection_mode == "map":
mask, _ = im.project_and_render_on_map(pose_base_in_world, footprints, color, map_resolution, map_size)

# Update traversability
mask = mask * pnode.traversability
supervision_masks = torch.fmin(supervision_masks, mask)
# mask = mask * pnode.traversability
supervision_masks = torch.fmin(supervision_masks, mask) # Adding the new mask to the supervision mask, using element-wise non-nan values

# Update supervision mask per node
for i, mnode in enumerate(mission_nodes):
mnode.supervision_mask = supervision_masks[i]
mnode.update_supervision_signal()
# mnode.update_supervision_signal() # Accumulate supervision signal, check if features are there

if self._mode == WVNMode.EXTRACT_LABELS:
p = os.path.join(
Expand All @@ -398,6 +412,7 @@ def add_proprio_node(self, pnode: ProprioceptionNode):
store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0
torch.save(store, p)


# if self._anomaly_detection:
# # Visualize supervision mask
# store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0
Expand Down
Loading

0 comments on commit 5798f0b

Please sign in to comment.