Skip to content

Commit

Permalink
Added moving average
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinSchmid7 committed Sep 8, 2023
1 parent e32a91f commit ad1a69b
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 4 deletions.
2 changes: 1 addition & 1 deletion wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class LossParams:

@dataclass
class LossAnomalyParams:
method: str = "latest_measurment" # "latest_measurment", "running_mean"
method: str = "moving_average" # "latest_measurment", "running_mean"
confidence_std_factor: float = 0.5

loss_anomaly: LossAnomalyParams = LossAnomalyParams()
Expand Down
35 changes: 32 additions & 3 deletions wild_visual_navigation/utils/confidence_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from wild_visual_navigation.utils import KalmanFilter
import torch
import os
from collections import deque


class ConfidenceGenerator(torch.nn.Module):
Expand All @@ -26,9 +27,13 @@ def __init__(
mean = torch.zeros(1, dtype=torch.float32)
var = torch.ones((1, 1), dtype=torch.float32)
std = torch.ones(1, dtype=torch.float32)
self.mean = torch.nn.Parameter(mean, requires_grad=False)
self.var = torch.nn.Parameter(var, requires_grad=False)
self.std = torch.nn.Parameter(std, requires_grad=False)
# self.mean = torch.nn.Parameter(mean, requires_grad=False)
# self.var = torch.nn.Parameter(var, requires_grad=False)
# self.std = torch.nn.Parameter(std, requires_grad=False)

self.mean = 0
self.var = 1
self.std = 1

if method == "kalman_filter":
kf_process_cov = 0.2
Expand Down Expand Up @@ -61,6 +66,11 @@ def __init__(
elif method == "latest_measurment":
self._update = self.update_latest_measurment
self._reset = self.reset_latest_measurment
elif method == "moving_average":
window_size = 5
self.data_window = deque(maxlen=window_size)
self._update = self.update_moving_average
self._reset = self.reset_moving_average
else:
raise ValueError("Unknown method")

Expand All @@ -75,6 +85,11 @@ def reset_latest_measurment(self, x: torch.Tensor, x_positive: torch.Tensor):
self.var[0] = 1
self.std[0] = 1

def reset_moving_average(self, x: torch.Tensor, x_positive: torch.Tensor):
self.mean[0] = 0
self.var[0] = 1
self.std[0] = 1

def update_running_mean(self, x: torch.tensor, x_positive: torch.tensor):
# We assume the positive samples' loss follows a Gaussian distribution
# We estimate the parameters empirically
Expand All @@ -98,6 +113,20 @@ def update_running_mean(self, x: torch.tensor, x_positive: torch.tensor):

return confidence.type(torch.float32)

def update_moving_average(self, x: torch.tensor, x_positive: torch.tensor):
self.data_window.append(x_positive)

data_window_tensor = list(self.data_window)
data_window_tensor = torch.cat(data_window_tensor, dim=0)

self.mean = data_window_tensor.mean()
self.std = data_window_tensor.std()

x = torch.clip(x, self.mean - 2 * self.std, self.mean + 2 * self.std)
confidence = (x - torch.min(x)) / (torch.max(x) - torch.min(x))

return confidence.type(torch.float32)

def update_kalman_filter(self, x: torch.tensor, x_positive: torch.tensor):
# Kalman Filter implementation
if x_positive.shape[0] != 0:
Expand Down
63 changes: 63 additions & 0 deletions wild_visual_navigation/visu/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from wild_visual_navigation.visu import paper_colors_rgb_f, paper_colors_rgba_f
from pytictac import Timer, accumulate_time

import rospy

__all__ = ["LearningVisualizer"]


Expand Down Expand Up @@ -369,6 +371,16 @@ def plot_detectron_classification(
(seg * 255).type(torch.long).clip(0, 255), max_seg=256, colormap=cmap, store=False, not_log=True
)

# plt.hist(seg_img.ravel(), bins=500)
# # Get current ros time
# now = rospy.Time.now()
# # Create a unique filename
# filename = f"{now.secs}_{now.nsecs}.png"
# # Save the figure
# plt.savefig(f"/home/rschmid/overlays/{filename}")
# # Close the figure
# plt.close()

H, W = img.shape[:2]
back = np.zeros((H, W, 4))
back[:, :, :3] = img
Expand Down Expand Up @@ -613,6 +625,57 @@ def plot_sparse_optical_flow(
# pass
# return np.array(pil_img).astype(np.uint8)

def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
'''
Function to offset the "center" of a colormap. Useful for
data with a negative min and positive max and you want the
middle of the colormap's dynamic range to be at zero.
Input
-----
cmap : The matplotlib colormap to be altered
start : Offset from lowest point in the colormap's range.
Defaults to 0.0 (no lower offset). Should be between
0.0 and `midpoint`.
midpoint : The new center of the colormap. Defaults to
0.5 (no shift). Should be between 0.0 and 1.0. In
general, this should be 1 - vmax / (vmax + abs(vmin))
For example if your data range from -15.0 to +5.0 and
you want the center of the colormap at 0.0, `midpoint`
should be set to 1 - 5/(5 + 15)) or 0.75
stop : Offset from highest point in the colormap's range.
Defaults to 1.0 (no upper offset). Should be between
`midpoint` and 1.0.
'''
cdict = {
'red': [],
'green': [],
'blue': [],
'alpha': []
}

# regular index to compute the colors
reg_index = np.linspace(start, stop, 257)

# shifted index to match the data
shift_index = np.hstack([
np.linspace(0.0, midpoint, 128, endpoint=False),
np.linspace(midpoint, 1.0, 129, endpoint=True)
])

for ri, si in zip(reg_index, shift_index):
r, g, b, a = cmap(ri)

cdict['red'].append((si, r, r))
cdict['green'].append((si, g, g))
cdict['blue'].append((si, b, b))
cdict['alpha'].append((si, a, a))

newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
plt.register_cmap(cmap=newcmap)

return newcmap


if __name__ == "__main__":
# Data was generated in the visu function of the lightning_module with the following code
Expand Down

0 comments on commit ad1a69b

Please sign in to comment.