From 8696d565e280415f71a414be082b3c78ea101533 Mon Sep 17 00:00:00 2001 From: RobinSchmid7 Date: Tue, 12 Sep 2023 18:10:52 +0200 Subject: [PATCH] Add param file for extraction --- scripts/dataset_generation/extract_binary_maps.py | 14 ++++++++++---- wild_visual_navigation/cfg/__init__.py | 1 + wild_visual_navigation/cfg/extraction_params.py | 11 +++++++++++ .../traversability_estimator.py | 9 ++++++++- 4 files changed, 30 insertions(+), 5 deletions(-) create mode 100644 wild_visual_navigation/cfg/extraction_params.py diff --git a/scripts/dataset_generation/extract_binary_maps.py b/scripts/dataset_generation/extract_binary_maps.py index 9c0eec1a..0392ad8d 100644 --- a/scripts/dataset_generation/extract_binary_maps.py +++ b/scripts/dataset_generation/extract_binary_maps.py @@ -11,6 +11,7 @@ import rosparam from sensor_msgs.msg import Image, CameraInfo, CompressedImage import rosbag +from wild_visual_navigation.cfg import ExtractionParams from postprocessing_tools_ros.merging import merge_bags_single, merge_bags_all @@ -64,11 +65,16 @@ def do(n, dry_run): s = os.path.join(ROOT_DIR, d["name"]) - valid_topics = ["/state_estimator/anymal_state", "/wide_angle_camera_front/img_out"] + extraction_cfg = ExtractionParams() - rosbags = ["/home/rschmid/RosBags/6_proc/images.bag", - "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_0.bag", - "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_1.bag"] + valid_topics = extraction_cfg.wvn_topics + rosbags = extraction_cfg.wvn_bags + + # valid_topics = ["/state_estimator/anymal_state", "/wide_angle_camera_front/img_out", "/depth_camera_front_upper/point_cloud_self_filtered"] + # + # rosbags = ["/home/rschmid/RosBags/6_proc/images.bag", + # "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_0.bag", + # "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_1.bag"] output_bag_wvn = s + "_wvn.bag" output_bag_tf = s + "_tf.bag" diff --git a/wild_visual_navigation/cfg/__init__.py b/wild_visual_navigation/cfg/__init__.py index 829340ea..c732142a 100644 --- a/wild_visual_navigation/cfg/__init__.py +++ b/wild_visual_navigation/cfg/__init__.py @@ -1 +1,2 @@ from .experiment_params import ExperimentParams +from .extraction_params import ExtractionParams diff --git a/wild_visual_navigation/cfg/extraction_params.py b/wild_visual_navigation/cfg/extraction_params.py new file mode 100644 index 00000000..4126bcde --- /dev/null +++ b/wild_visual_navigation/cfg/extraction_params.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass, field, asdict +from typing import Tuple, Dict, List, Optional, Any + +@dataclass +class ExtractionParams: + wvn_topics: List[str] = field(default_factory=lambda: ["/state_estimator/anymal_state", "/wide_angle_camera_front/img_out", "/depth_camera_front_upper/point_cloud_self_filtered"]) + wvn_bags: List[str] = field(default_factory=lambda: ["/home/rschmid/RosBags/6_proc/images.bag", + "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_0.bag", + "/home/rschmid/RosBags/6_proc/2023-03-02-11-13-08_anymal-d020-lpc_mission_1.bag"]) + +data: ExtractionParams = ExtractionParams() \ No newline at end of file diff --git a/wild_visual_navigation/traversability_estimator/traversability_estimator.py b/wild_visual_navigation/traversability_estimator/traversability_estimator.py index edc2070f..1d39ae12 100644 --- a/wild_visual_navigation/traversability_estimator/traversability_estimator.py +++ b/wild_visual_navigation/traversability_estimator/traversability_estimator.py @@ -406,12 +406,19 @@ def add_proprio_node(self, pnode: ProprioceptionNode, projection_mode: str = "im if self._mode == WVNMode.EXTRACT_LABELS: p = os.path.join( self._extraction_store_folder, - "supervision_mask", + "data_mask", str(mnode.timestamp).replace(".", "_") + ".pt", ) store = torch.nan_to_num(mnode.supervision_mask.nanmean(axis=0)) != 0 torch.save(store, p) + if self._mode == WVNMode.EXTRACT_LABELS: + p = os.path.join( + self._extraction_store_folder, + "data_image", + str(mnode.timestamp).replace(".", "_") + ".pt", + ) + torch.save(mnode.image, p) # if self._anomaly_detection: # # Visualize supervision mask