From d33f6109355669308d554f43c1644a8497d28d6b Mon Sep 17 00:00:00 2001
From: RobinSchmid7
Date: Tue, 29 Aug 2023 16:18:05 +0200
Subject: [PATCH] Added comments, changes to default rnvp config

---
 wild_visual_navigation/cfg/experiment_params.py |  2 +-
 .../learning/model/linear_rnvp.py               | 15 ++++++++-------
 wild_visual_navigation/learning/utils/loss.py   |  2 +-
 .../traversability_estimator.py                 |  5 +----
 .../config/wild_visual_navigation/default.yaml  |  8 ++++----
 .../inputs/wide_angle_front.yaml                |  4 ++--
 .../inputs/wide_angle_front_compressed.yaml     |  4 ++--
 .../scripts/wvn_feature_extractor_node.py       |  4 ++--
 8 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/wild_visual_navigation/cfg/experiment_params.py b/wild_visual_navigation/cfg/experiment_params.py
index ee6c17f0..b7d3a0a3 100644
--- a/wild_visual_navigation/cfg/experiment_params.py
+++ b/wild_visual_navigation/cfg/experiment_params.py
@@ -124,7 +124,7 @@ class LinearRnvpCfgParams:
         mask_type: str = "odds"
         conditioning_size: int = 0
         use_permutation: bool = True
-        single_function: bool = True
+        single_function: bool = False

     linear_rnvp_cfg: LinearRnvpCfgParams = LinearRnvpCfgParams()

diff --git a/wild_visual_navigation/learning/model/linear_rnvp.py b/wild_visual_navigation/learning/model/linear_rnvp.py
index 89125307..e2a3e1d4 100644
--- a/wild_visual_navigation/learning/model/linear_rnvp.py
+++ b/wild_visual_navigation/learning/model/linear_rnvp.py
@@ -114,7 +114,7 @@ def backward(self, x, y=None):
         _mx = mx

         s, t = self.st(_mx)
-        s = torch.tanh(s)
+        s = torch.tanh(s)  # tanh non-linearity keeps the scale factor s bounded

         u = mx + (1 - self.mask) * (x - t) * torch.exp(-s)

@@ -221,8 +221,8 @@ def __init__(
     ):
         super().__init__()

-        self.register_buffer("prior_mean", torch.zeros(input_dim))
-        self.register_buffer("prior_var", torch.ones(input_dim))
+        self.register_buffer("prior_mean", torch.zeros(input_dim))  # Normal Gaussian with zero mean
+        self.register_buffer("prior_var", torch.ones(input_dim))  # Normal Gaussian with unit variance

         if mask_type == "odds":
             mask = torch.arange(0, input_dim).float() % 2
@@ -258,16 +258,17 @@ def __init__(
         self.flows = SequentialFlow(*blocks)

     def logprob(self, x):
-        return self.prior.log_prob(x)
+        return self.prior.log_prob(x)  # Log probability of the input under the Gaussian prior

     @property
     def prior(self):
-        return distributions.Normal(self.prior_mean, self.prior_var)
+        return distributions.Normal(self.prior_mean, self.prior_var)  # Normal Gaussian with zero mean and unit variance

     def forward(self, data: Data):
         x = data.x
-        z, log_det = self.flows.forward(x, None)
-        return {"z": z, "log_det": log_det, "logprob": self.logprob(z)}
+        z, log_det = self.flows.forward(x, y=None)
+        log_prob = self.logprob(z)
+        return {"z": z, "log_det": log_det, "logprob": log_prob}

     def backward(self, u, y=None, return_step=False):
         if return_step:
diff --git a/wild_visual_navigation/learning/utils/loss.py b/wild_visual_navigation/learning/utils/loss.py
index 079e7403..cb21905a 100644
--- a/wild_visual_navigation/learning/utils/loss.py
+++ b/wild_visual_navigation/learning/utils/loss.py
@@ -25,7 +25,7 @@ def forward(
             loss_aux["loss_reco"] = torch.tensor([0.0])
             loss_aux["confidence"] = torch.tensor([0.0])

-        losses = -(res["logprob"].sum(1) + res["log_det"])
+        losses = -(res["logprob"].sum(1) + res["log_det"])  # Sum over all channels, resulting in h*w output dimensions
         # print(torch.mean(losses))

         l_clip = losses
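Note on the annotated loss term above: forward() of the flow returns the latent z, the log-determinant of the flow Jacobian, and the element-wise log-probability of z under the Gaussian prior, and the loss is their negated sum. A minimal sketch of that computation with stand-in tensors (the n_samples/n_dims sizes are assumptions for illustration, not taken from the repository):

import torch

# Stand-in for the dictionary returned by LinearRnvp.forward() as shown in the hunks above:
# "logprob" holds one value per latent dimension, "log_det" one value per sample.
n_samples, n_dims = 4, 8
res = {
    "logprob": torch.randn(n_samples, n_dims),  # log N(z; 0, I), element-wise
    "log_det": torch.randn(n_samples),          # log|det J| of the flow, per sample
}

# Per-sample negative log-likelihood, as in the annotated line of loss.py:
# summing "logprob" over the feature dimension leaves one anomaly score per sample
# (one per pixel when the batch is a flattened h*w feature map).
losses = -(res["logprob"].sum(1) + res["log_det"])
print(losses.shape)  # torch.Size([4])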
diff --git a/wild_visual_navigation/traversability_estimator/traversability_estimator.py b/wild_visual_navigation/traversability_estimator/traversability_estimator.py
index 84fa7873..695f5816 100644
--- a/wild_visual_navigation/traversability_estimator/traversability_estimator.py
+++ b/wild_visual_navigation/traversability_estimator/traversability_estimator.py
@@ -631,10 +631,7 @@ def train(self):

         with self._learning_lock:
             # Forward pass
-            if self._anomaly_detection:
-                res = self._model(graph)
-            else:
-                res = self._model(graph)
+            res = self._model(graph)

             log_step = (self._step % 20) == 0
             self._loss, loss_aux, trav = self._traversability_loss(graph, res, step=self._step, log_step=log_step, loss_mean=self._loss_mean, loss_std=self._loss_std, train=True)
diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml
index 09aabf7e..cd5d20f8 100644
--- a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml
+++ b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml
@@ -30,11 +30,11 @@ confidence_std_factor: 4.0
 scale_traversability: False # This parameter needs to be false when using the anomaly detection model
 scale_traversability_max_fpr: 0.25
 min_samples_for_training: 5
-prediction_per_pixel: False
+prediction_per_pixel: false
 traversability_threshold: 0.55
-clip_to_binary: False
-anomaly_detection: True
-vis_training_samples: True
+clip_to_binary: false
+anomaly_detection: true
+vis_training_samples: true

 # Supervision Generator
 untraversable_thr: 0.01
diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front.yaml
index 50b65278..41106092 100644
--- a/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front.yaml
+++ b/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front.yaml
@@ -7,5 +7,5 @@ camera_topics:
     publish_input_image: true

 # Provides 1080 (height) x 1920 (width) images
-network_input_image_height: 224 # 448
-network_input_image_width: 224 # 448
\ No newline at end of file
+network_input_image_height: 448 # 448
+network_input_image_width: 448 # 448
\ No newline at end of file
diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml
index df07ccb8..7afa0a66 100644
--- a/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml
+++ b/wild_visual_navigation_ros/config/wild_visual_navigation/inputs/wide_angle_front_compressed.yaml
@@ -7,5 +7,5 @@ camera_topics:
     publish_input_image: true

 # Provides 1080 (height) x 1920 (width) images
-network_input_image_height: 224 # 448
-network_input_image_width: 224 # 448
\ No newline at end of file
+network_input_image_height: 448 # 448
+network_input_image_width: 448 # 448
\ No newline at end of file
diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
index 0b2d4526..a06f5352 100644
--- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
+++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
@@ -256,7 +256,7 @@ def image_callback(self, image_msg: Image, cam: str):  # info_msg: CameraInfo
             # Evaluate traversability
             data = Data(x=dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1]))
         else:
-            input_feat = dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1])
+            # input_feat = dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1])
             input_feat = feat[seg.reshape(-1)]
             data = Data(x=input_feat)

@@ -350,7 +350,7 @@ def load_model(self):
             res = torch.load(f"{WVN_ROOT_DIR}/tmp_state_dict2.pt")
             k = list(self.model.state_dict().keys())[-1]

-            if (self.model.state_dict()[k] != res[k]).any():  # TODO: model params are changing
+            if (self.model.state_dict()[k] != res[k]).any():  # TODO: model params are changing?
                 if self.verbose:
                     self.log_data[f"time_last_model"] = rospy.get_time()
                     self.log_data[f"nr_model_updates"] += 1
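For reference, the indexing kept in the hunk above (input_feat = feat[seg.reshape(-1)]) broadcasts one feature vector per segment to every pixel of that segment. A small shape sketch with made-up tensors (the 3-segment / 2x4 map sizes are assumptions for illustration only):

import torch

# Hypothetical sizes: 3 segments with 5-dim features, a 2x4 segment-index map.
feat = torch.arange(15.0).reshape(3, 5)   # (num_segments, feature_dim)
seg = torch.tensor([[0, 0, 1, 2],
                    [1, 1, 2, 2]])        # (H, W), segment id per pixel

# Indexing with the flattened segment map copies each segment's feature vector
# to all pixels belonging to that segment.
input_feat = feat[seg.reshape(-1)]        # (H*W, feature_dim)
print(input_feat.shape)                   # torch.Size([8, 5])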