Skip to content

Commit

Permalink
Added comments, changes to default rnvp config
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinSchmid7 committed Aug 29, 2023
1 parent 5385241 commit d33f610
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 23 deletions.
2 changes: 1 addition & 1 deletion wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class LinearRnvpCfgParams:
mask_type: str = "odds"
conditioning_size: int = 0
use_permutation: bool = True
single_function: bool = True
single_function: bool = False

linear_rnvp_cfg: LinearRnvpCfgParams = LinearRnvpCfgParams()

Expand Down
15 changes: 8 additions & 7 deletions wild_visual_navigation/learning/model/linear_rnvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def backward(self, x, y=None):
_mx = mx

s, t = self.st(_mx)
s = torch.tanh(s)
s = torch.tanh(s) # Adding an activation function here with non-linearities

u = mx + (1 - self.mask) * (x - t) * torch.exp(-s)

Expand Down Expand Up @@ -221,8 +221,8 @@ def __init__(
):
super().__init__()

self.register_buffer("prior_mean", torch.zeros(input_dim))
self.register_buffer("prior_var", torch.ones(input_dim))
self.register_buffer("prior_mean", torch.zeros(input_dim)) # Normal Gaussian with zero mean
self.register_buffer("prior_var", torch.ones(input_dim)) # Normal Gaussian with unit variance

if mask_type == "odds":
mask = torch.arange(0, input_dim).float() % 2
Expand Down Expand Up @@ -258,16 +258,17 @@ def __init__(
self.flows = SequentialFlow(*blocks)

def logprob(self, x):
return self.prior.log_prob(x)
return self.prior.log_prob(x) # Compute log probability of the input at the Gaussian distribution

@property
def prior(self):
return distributions.Normal(self.prior_mean, self.prior_var)
return distributions.Normal(self.prior_mean, self.prior_var) # Normal Gaussian with zero mean and unit variance

def forward(self, data: Data):
x = data.x
z, log_det = self.flows.forward(x, None)
return {"z": z, "log_det": log_det, "logprob": self.logprob(z)}
z, log_det = self.flows.forward(x, y=None)
log_prob = self.logprob(z)
return {"z": z, "log_det": log_det, "logprob": log_prob}

def backward(self, u, y=None, return_step=False):
if return_step:
Expand Down
2 changes: 1 addition & 1 deletion wild_visual_navigation/learning/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def forward(
loss_aux["loss_reco"] = torch.tensor([0.0])
loss_aux["confidence"] = torch.tensor([0.0])

losses = -(res["logprob"].sum(1) + res["log_det"])
losses = -(res["logprob"].sum(1) + res["log_det"]) # Sum over all channels, resulting in h*w output dimensions

# print(torch.mean(losses))
l_clip = losses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,10 +631,7 @@ def train(self):
with self._learning_lock:
# Forward pass

if self._anomaly_detection:
res = self._model(graph)
else:
res = self._model(graph)
res = self._model(graph)

log_step = (self._step % 20) == 0
self._loss, loss_aux, trav = self._traversability_loss(graph, res, step=self._step, log_step=log_step, loss_mean=self._loss_mean, loss_std=self._loss_std, train=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ confidence_std_factor: 4.0
scale_traversability: False # This parameter needs to be false when using the anomaly detection model
scale_traversability_max_fpr: 0.25
min_samples_for_training: 5
prediction_per_pixel: False
prediction_per_pixel: false
traversability_threshold: 0.55
clip_to_binary: False
anomaly_detection: True
vis_training_samples: True
clip_to_binary: false
anomaly_detection: true
vis_training_samples: true

# Supervision Generator
untraversable_thr: 0.01
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ camera_topics:
publish_input_image: true

# Provides 1080 (height) x 1920 (width) images
network_input_image_height: 224 # 448
network_input_image_width: 224 # 448
network_input_image_height: 448 # 448
network_input_image_width: 448 # 448
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ camera_topics:
publish_input_image: true

# Provides 1080 (height) x 1920 (width) images
network_input_image_height: 224 # 448
network_input_image_width: 224 # 448
network_input_image_height: 448 # 448
network_input_image_width: 448 # 448
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
# Evaluate traversability
data = Data(x=dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1]))
else:
input_feat = dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1])
# input_feat = dense_feat[0].permute(1, 2, 0).reshape(-1, dense_feat.shape[1])
input_feat = feat[seg.reshape(-1)]
data = Data(x=input_feat)

Expand Down Expand Up @@ -350,7 +350,7 @@ def load_model(self):
res = torch.load(f"{WVN_ROOT_DIR}/tmp_state_dict2.pt")
k = list(self.model.state_dict().keys())[-1]

if (self.model.state_dict()[k] != res[k]).any(): # TODO: model params are changing
if (self.model.state_dict()[k] != res[k]).any(): # TODO: model params are changing?
if self.verbose:
self.log_data[f"time_last_model"] = rospy.get_time()
self.log_data[f"nr_model_updates"] += 1
Expand Down

0 comments on commit d33f610

Please sign in to comment.