Commit

anomaly_detection_running
JonasFrey96 committed Feb 18, 2024
1 parent 35c8fdb commit f4d35e4
Showing 7 changed files with 80 additions and 90 deletions.
2 changes: 1 addition & 1 deletion tests/test_confidence_generator.py
@@ -73,7 +73,7 @@ def test_confidence_generator():
conf = torch.zeros((3, N), device=device)

# Naive confidence generator
cg = ConfidenceGenerator(std_factor=sigma_factor).to(device)
cg = ConfidenceGenerator(std_factor=sigma_factor, method="latest_measurement").to(device)

# Run
for i in range(x.shape[0]):
2 changes: 1 addition & 1 deletion wild_visual_navigation/cfg/experiment_params.py
@@ -59,7 +59,7 @@ class LossParams:

@dataclass
class LossAnomalyParams:
method: str = "running_mean" # "latest_measurement", "running_mean", "moving_average"
method: str = "latest_measurement"
confidence_std_factor: float = 0.5

loss_anomaly: LossAnomalyParams = LossAnomalyParams()
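For orientation, a minimal sketch of how these dataclass fields travel onward, assuming nothing beyond what the diffs in this commit show: AnomalyLoss (see loss.py below) forwards method and confidence_std_factor straight into ConfidenceGenerator. The commented call is illustrative, not the project's exact call site.

```python
from dataclasses import dataclass, asdict

@dataclass
class LossAnomalyParams:
    method: str = "latest_measurement"      # default switched from "running_mean" in this commit
    confidence_std_factor: float = 0.5

params = LossAnomalyParams()
# Hypothetical wiring, mirroring AnomalyLoss.__init__ further down:
# loss = AnomalyLoss(**asdict(params), log_enabled=False, log_folder="/tmp")
print(asdict(params))   # {'method': 'latest_measurement', 'confidence_std_factor': 0.5}
```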
48 changes: 16 additions & 32 deletions wild_visual_navigation/utils/confidence_generator.py
@@ -13,11 +13,10 @@
class ConfidenceGenerator(torch.nn.Module):
def __init__(
self,
std_factor: float = 0.7,
method: str = "running_mean",
std_factor,
method,
log_enabled: bool = False,
log_folder: str = f"{WVN_ROOT_DIR}/results",
anomaly_detection: bool = False,
):
"""Returns a confidence value for each number
@@ -29,7 +28,6 @@ def __init__(

self.log_enabled = log_enabled
self.log_folder = log_folder
self.anomaly_detection = anomaly_detection

mean = torch.zeros(1, dtype=torch.float32)
var = torch.ones((1, 1), dtype=torch.float32)
@@ -107,19 +105,12 @@ def update_running_mean(self, x: torch.tensor, x_positive: torch.tensor):
if x.device != self.mean.device:
return torch.zeros_like(x)

if self.anomaly_detection:
x = torch.clip(x, self.mean - 2 * self.std, self.mean + 2 * self.std)
confidence = (x - torch.min(x)) / (torch.max(x) - torch.min(x))
else:
# Then the confidence is computed as the distance to the center of the Gaussian given factor*sigma
# confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5)
# confidence[x < self.mean] = 1.0

shifted_mean = self.mean + self.std * self.std_factor
interval_min = shifted_mean - 2 * self.std
interval_max = shifted_mean + 2 * self.std
x = torch.clip(x, interval_min, interval_max)
confidence = 1 - ((x - interval_min) / (interval_max - interval_min))
shifted_mean = self.mean + self.std * self.std_factor
std_fac = 1
interval_min = max(shifted_mean - std_fac * self.std, 0)
interval_max = shifted_mean + std_fac * self.std
x = torch.clip(x, interval_min, interval_max)
confidence = 1 - ((x - interval_min) / (interval_max - interval_min))

return confidence.type(torch.float32)
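As a standalone illustration of the new mapping, the snippet below re-implements these lines outside the class with made-up statistics; torch.clamp(..., min=0.0) stands in for the built-in max(..., 0) used above, and the std_factor of 0.5 is only an example value.

```python
import torch

def confidence_from_stats(x, mean, std, std_factor=0.5):
    # Centre of the ramp sits std_factor standard deviations above the running mean.
    shifted_mean = mean + std * std_factor
    # The ramp spans one std on either side of that centre, floored at zero.
    interval_min = torch.clamp(shifted_mean - std, min=0.0)
    interval_max = shifted_mean + std
    x = torch.clip(x, interval_min, interval_max)
    # Values at or below interval_min map to confidence 1, at interval_max to 0.
    return 1 - (x - interval_min) / (interval_max - interval_min)

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
print(confidence_from_stats(x, mean=torch.tensor(1.0), std=torch.tensor(1.0)))
# tensor([1.0000, 0.7500, 0.2500, 0.0000])
```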

@@ -189,22 +180,15 @@ def update(
return output

def inference_without_update(self, x: torch.tensor):

if x.device != self.mean.device:
return torch.zeros_like(x)

if self.anomaly_detection:
x = torch.clip(x, self.mean - 2 * self.std, self.mean + 2 * self.std)
confidence = (x - torch.min(x)) / (torch.max(x) - torch.min(x))

else:
shifted_mean = self.mean + self.std * self.std_factor
interval_min = shifted_mean - 2 * self.std
interval_max = shifted_mean + 2 * self.std
x = torch.clip(x, interval_min, interval_max)
confidence = 1 - ((x - interval_min) / (interval_max - interval_min))

# confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5)
# confidence[x < self.mean] = 1.0
shifted_mean = self.mean + self.std * self.std_factor
std_fac = 1
interval_min = max(shifted_mean - std_fac * self.std, 0)
interval_max = shifted_mean + std_fac * self.std
x = torch.clip(x, interval_min, interval_max)
confidence = 1 - ((x - interval_min) / (interval_max - interval_min))

return confidence.type(torch.float32)

@@ -229,7 +213,7 @@ def get_dict(self):


if __name__ == "__main__":
cg = ConfidenceGenerator()
cg = ConfidenceGenerator(std_factor=0.5, method="latest_measurement")
for i in range(1000):
inp = torch.rand(10) * 10
res = cg.update(inp, inp, step=i)
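A short usage sketch of the two entry points visible in this file, mirroring the __main__ block above; the import path is assumed to follow the file path, and the random scores are placeholders for real loss values.

```python
import torch
from wild_visual_navigation.utils.confidence_generator import ConfidenceGenerator

cg = ConfidenceGenerator(std_factor=0.5, method="latest_measurement")

# Training-style loop: running statistics are refreshed from every batch of scores.
for step in range(100):
    scores = torch.rand(10) * 10
    cg.update(scores, scores, step=step)

# Deployment-style call: confidence only, the tracked mean/std stay untouched.
with torch.no_grad():
    print(cg.inference_without_update(torch.rand(10) * 10))
```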
16 changes: 5 additions & 11 deletions wild_visual_navigation/utils/loss.py
@@ -24,11 +24,7 @@ def __init__(
super(AnomalyLoss, self).__init__()

self._confidence_generator = ConfidenceGenerator(
std_factor=confidence_std_factor,
method=method,
log_enabled=log_enabled,
log_folder=log_folder,
anomaly_detection=True,
std_factor=confidence_std_factor, method=method, log_enabled=log_enabled, log_folder=log_folder
)

def forward(
@@ -46,7 +42,9 @@ def forward(
losses = res["logprob"].sum(1) + res["log_det"] # Sum over all channels, resulting in h*w output dimensions

if update_generator:
confidence = self._confidence_generator.update(x=losses, x_positive=losses, step=step)
confidence = self._confidence_generator.update(
x=-losses.clone().detach(), x_positive=-losses.clone().detach(), step=step
)

loss_aux["confidence"] = confidence

@@ -85,11 +83,7 @@ def __init__(
self._trav_loss_func = F.mse_loss

self._confidence_generator = ConfidenceGenerator(
std_factor=confidence_std_factor,
method=method,
log_enabled=log_enabled,
log_folder=log_folder,
anomaly_detection=False,
std_factor=confidence_std_factor, method=method, log_enabled=log_enabled, log_folder=log_folder
)

def reset(self):
(Changed YAML configuration file; file name not captured in this view)
@@ -24,7 +24,7 @@ feature_type: "stego" # Options: dino, dinov2, stego
dino_patch_size: 8 # Options: 8, 16; We found 8 is sufficient and faster
dino_backbone: vit_small # Options: vit_small
slic_num_components: 100 # Number of segments for slic
confidence_std_factor: 4.0 # Tuning parameter to change the confidence computation / explained in confidence_generator.py
confidence_std_factor: 1.0 # Tuning parameter to change the confidence computation / explained in confidence_generator.py
min_samples_for_training: 5 # Minimum number of mission nodes with successful reprojection before training starts
prediction_per_pixel: True # If true, trained network is inferenced for each pixel, otherwise per segment
vis_node_index: 10 # Defines node which is used for visualization, can help to debug the projected footprint
@@ -41,7 +41,7 @@ supervision_callback_rate: 10 # Maximum rate at which supervision_signals
learning_thread_rate: 10 # Gradient steps per second in hertz
logging_thread_rate: 2 # Logging of learning_node in hertz
status_thread_rate: 0.5 # Status of feature_extractor_node in hertz
load_save_checkpoint_rate: 0.2 # Rate at which checkpoints are saved / loaded between learning and feature extraction
load_save_checkpoint_rate: 1.0 # Rate at which checkpoints are saved / loaded between learning and feature extraction


# Runtime options
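Tying this value back to the interval logic in confidence_generator.py above (assuming, as the read_params change below suggests, that this ROS parameter ends up as the generator's std_factor): with confidence_std_factor = 1.0 the confidence ramp is centred at μ + σ, so scores at or below the running mean μ map to confidence 1, scores at μ + 2σ or above map to 0, and everything in between falls off linearly. The previous value of 4.0 would have centred the ramp at μ + 4σ, far above typical scores.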
50 changes: 28 additions & 22 deletions wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
@@ -67,20 +67,15 @@ def __init__(self, node_name):
self._model = get_model(self._params.model).to(self._ros_params.device)
self._model.eval()

if not self.anomaly_detection:
if self.anomaly_detection:
self._confidence_generator = ConfidenceGenerator(
method=self._params.loss.method,
std_factor=self._params.loss.confidence_std_factor,
anomaly_detection=self.anomaly_detection,
method=self._params.loss_anomaly.method, std_factor=self._params.loss_anomaly.confidence_std_factor
)

else:
self._anomaly_loss = AnomalyLoss(
**self._params.loss_anomaly,
log_enabled=self._params.general.log_confidence,
log_folder=self._params.general.model_path,
self._confidence_generator = ConfidenceGenerator(
method=self._params.loss.method, std_factor=self._params.loss.confidence_std_factor
)
self._anomaly_loss.to(self._ros_params.device)

self._log_data = {}
self.setup_ros()

@@ -116,6 +111,7 @@ def read_params(self):

with read_write(self._params):
self._params.loss.confidence_std_factor = self._ros_params.confidence_std_factor
self._params.loss_anomaly.confidence_std_factor = self._ros_params.confidence_std_factor

self.anomaly_detection = self._params.model.name == "LinearRnvp"

@@ -335,7 +331,9 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
if not self.anomaly_detection:
out_trav = prediction.reshape(H, W, -1)[:, :, 0]
else:
loss, loss_aux, trav = self._anomaly_loss(None, prediction)
losses = prediction["logprob"].sum(1) + prediction["log_det"]
confidence = self._confidence_generator.inference_without_update(x=-losses)
trav = confidence
out_trav = trav.reshape(H, W, -1)[:, :, 0]

msg = rc.numpy_to_ros_image(out_trav.cpu().numpy(), "passthrough")
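A shape-level sketch of the anomaly branch added above, runnable on its own: the per-pixel flow output is reduced to one score per pixel, turned into confidence (here torch.sigmoid is a placeholder standing in for inference_without_update), and reshaped back into an image.

```python
import torch

H, W = 4, 6                                     # illustrative image size
prediction = {"logprob": torch.randn(H * W, 8), "log_det": torch.randn(H * W)}

losses = prediction["logprob"].sum(1) + prediction["log_det"]
# In the node: confidence = self._confidence_generator.inference_without_update(x=-losses)
confidence = torch.sigmoid(-losses)             # placeholder so the sketch runs standalone
out_trav = confidence.reshape(H, W, -1)[:, :, 0]
print(out_trav.shape)                           # torch.Size([4, 6])
```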
@@ -426,18 +424,26 @@ def load_model(self, stamp):
new_model_state_dict = torch.load(p)
k = list(self._model.state_dict().keys())[-1]

if (self._model.state_dict()[k] != new_model_state_dict[k]).any():
if self._ros_params.verbose:
self._log_data[f"time_last_model"] = rospy.get_time()
self._log_data[f"nr_model_updates"] += 1

self._model.load_state_dict(new_model_state_dict, strict=False)
# check if the key is in state dict - this may be not the case if switched between models
# assumption first key within state_dict is unique and sufficient to identify if a model has changed
if k in new_model_state_dict:
# check if the model has changed
if (self._model.state_dict()[k] != new_model_state_dict[k]).any():
if self._ros_params.verbose:
self._log_data[f"time_last_model"] = rospy.get_time()
self._log_data[f"nr_model_updates"] += 1

self._model.load_state_dict(new_model_state_dict, strict=False)
if "confidence_generator" in new_model_state_dict.keys():
cg = new_model_state_dict["confidence_generator"]
self._confidence_generator.var = cg["var"]
self._confidence_generator.mean = cg["mean"]
self._confidence_generator.std = cg["std"]

if self._ros_params.verbose:
m, s, v = cg["mean"].item(), cg["std"].item(), cg["var"].item()
rospy.loginfo(f"[{self._node_name}] Loaded Confidence Generator {m}, std {s} var {v}")

if "confidence_generator" in new_model_state_dict.keys():
cg = new_model_state_dict["confidence_generator"]
self._confidence_generator.var = cg["var"]
self._confidence_generator.mean = cg["mean"]
self._confidence_generator.std = cg["std"]
else:
if self._ros_params.verbose:
rospy.logerr(f"[{self._node_name}] Model Loading Failed")
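The new guard keys checkpoint compatibility off a single parameter name; a toy version of the same heuristic, with hypothetical module choices, shows why a missing key signals a different architecture.

```python
import torch

def checkpoint_matches(model: torch.nn.Module, ckpt: dict) -> bool:
    # Same assumption as load_model above: if the model's last parameter key is
    # absent from the checkpoint, the checkpoint came from a different architecture.
    last_key = list(model.state_dict().keys())[-1]
    return last_key in ckpt

model = torch.nn.Linear(4, 2)
same_arch = model.state_dict()                              # compatible checkpoint
other_arch = torch.nn.Sequential(torch.nn.Linear(4, 2)).state_dict()
print(checkpoint_matches(model, same_arch))    # True
print(checkpoint_matches(model, other_arch))   # False ('bias' vs '0.bias')
```

Without this guard, the comparison self._model.state_dict()[k] != new_model_state_dict[k] would raise a KeyError whenever the checkpoint on disk was written by a different model, which is the failure the added else branch now logs instead.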