Skip to content

Commit

Permalink
Disable pesky mean warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexanders101 committed Jul 12, 2023
1 parent ad35557 commit 8c2605b
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions spanet/network/jet_reconstruction/jet_reconstruction_validation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Callable
import warnings

import numpy as np
import torch
Expand Down Expand Up @@ -62,13 +63,16 @@ def compute_metrics(self, jet_predictions, particle_scores, stacked_targets, sta
particle_accuracies = particle_accuracies.max(0)

# Create the logging dictionaries
metrics = {f"jet/accuracy_{i}_of_{j}": (jet_accuracies[num_particles == j] >= i).mean()
for j in range(1, num_targets + 1)
for i in range(1, j + 1)}

metrics.update({f"particle/accuracy_{i}_of_{j}": (particle_accuracies[num_particles == j] >= i).mean()
for j in range(1, num_targets + 1)
for i in range(1, j + 1)})
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=RuntimeWarning)

metrics = {f"jet/accuracy_{i}_of_{j}": (jet_accuracies[num_particles == j] >= i).mean()
for j in range(1, num_targets + 1)
for i in range(1, j + 1)}

metrics.update({f"particle/accuracy_{i}_of_{j}": (particle_accuracies[num_particles == j] >= i).mean()
for j in range(1, num_targets + 1)
for i in range(1, j + 1)})

particle_scores = particle_scores.ravel()
particle_targets = permuted_masks.ravel()
Expand Down

0 comments on commit 8c2605b

Please sign in to comment.