Add tests for visualize module; refactor label and prediction distribution plots.

Co-authored-by: Tyler Morrow <[email protected]>
anbusto and tymorrow committed Jun 27, 2023
1 parent 57f613f commit 58aa296
Showing 5 changed files with 309 additions and 105 deletions.
3 changes: 2 additions & 1 deletion examples/run_examples.py
@@ -15,6 +15,7 @@
SUCCESS_STR = "Success"
FAILURE_STR = "Fail"

original_wdir = os.getcwd()
example_dir = Path(__file__).parent
os.chdir(example_dir)

@@ -44,7 +45,7 @@
FILENAME_KEY: os.path.relpath(f, example_dir),
RESULT_KEY: SUCCESS_STR if not return_code else FAILURE_STR
}
os.chdir(example_dir)
os.chdir(original_wdir)

df = pd.DataFrame.from_dict(results, orient="index")
tabulated_df = tabulate(df, headers="keys", tablefmt="psql")
4 changes: 2 additions & 2 deletions riid/models/neural_nets.py
@@ -233,7 +233,7 @@ def fit(self, ss: SampleSet, bg_ss: SampleSet = None,

return history

def predict(self, ss: SampleSet, bg_ss: SampleSet = None):
def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False):
"""Classifies the spectra in the provided SampleSet(s).
Results are stored inside the first SampleSet's prediction-related properties.
@@ -249,7 +249,7 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None):
X = [x_test, bg_ss.get_samples().astype(float)]
else:
X = x_test
results = self.model.predict(X) # output size will be n_samples by n_labels
results = self.model.predict(X, verbose=verbose)

col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level)
col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1]
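A minimal usage sketch of the updated predict signature, assuming a trained riid neural network model named model and SampleSets test_ss and bg_ss (hypothetical names); the new verbose argument is forwarded to the underlying Keras predict call:

    # Hypothetical objects: any trained riid model plus foreground/background SampleSets.
    model.predict(test_ss, bg_ss, verbose=False)  # quiet inference; results stored on test_ss
    predictions = test_ss.get_predictions(target_level="Isotope")
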
234 changes: 133 additions & 101 deletions riid/visualize.py
@@ -133,7 +133,7 @@ def plot_live_time_vs_snr(ss: SampleSet, overlay_ss: SampleSet = None, alpha: fl
xscale: str = "linear", yscale: str = "log",
xlim: tuple = None, ylim: tuple = None,
title: str = "Live Time vs. SNR", snr_line_value: float = None,
figsize=(6.4, 4.8)):
figsize=(6.4, 4.8), target_level: str = "Isotope"):
"""Plots SNR against live time for all samples in a SampleSet.
Prediction and label information is used to distinguish between correct and incorrect
@@ -143,10 +143,10 @@ def plot_live_time_vs_snr(ss: SampleSet, overlay_ss: SampleSet = None, alpha: fl
ss: Defines a SampleSet of events to plot.
overlay_ss: Defines another SampleSet to color as black.
alpha: Defines the degree of opacity (not applied to overlay_ss scatterplot if used).
xscale: Defines the X-axis scale.
yscale: Defines the Y-axis scale.
xlim: Defines a tuple containing the X-axis min and max values.
ylim: Defines a tuple containing the Y-axis min and max values.
xscale: Defines the x-axis scale.
yscale: Defines the y-axis scale.
xlim: tuple containing the x-axis min and max values.
ylim: tuple containing the y-axis min and max values.
title: Defines the plot title.
snr_line_value: Plots a vertical line for contextualizing data to threshold
figsize: Width, height of figure in inches.
@@ -155,8 +155,8 @@ def plot_live_time_vs_snr(ss: SampleSet, overlay_ss: SampleSet = None, alpha: fl
A tuple (Figure, Axes) of matplotlib objects.
"""
labels = ss.get_labels()
predictions = ss.get_predictions()
labels = ss.get_labels(target_level=target_level)
predictions = ss.get_predictions(target_level=target_level)
correct_ss = ss[labels == predictions]
incorrect_ss = ss[labels != predictions]
if not xlim:
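
A short sketch of the new target_level argument, assuming a SampleSet test_ss with labels and predictions already populated (hypothetical name); correct vs. incorrect coloring is now computed at the chosen level of the sources multi-index:

    from riid.visualize import plot_live_time_vs_snr

    # Distinguish correct from incorrect predictions at the Isotope level.
    fig, ax = plot_live_time_vs_snr(test_ss, target_level="Isotope")
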
@@ -222,10 +222,10 @@ def plot_snr_vs_score(ss: SampleSet, overlay_ss: SampleSet = None, alpha: float
ss: Defines a SampleSet of events to plot.
overlay_ss: Defines another SampleSet to color as blue (correct) and/or black (incorrect).
alpha: Defines the degree of opacity (not applied to overlay_ss scatterplot if used).
xscale: Defines the X-axis scale.
yscale: Defines the Y-axis scale.
xlim: Defines a tuple containing the X-axis min and max values.
ylim: Defines a tuple containing the Y-axis min and max values.
xscale: Defines the x-axis scale.
yscale: Defines the y-axis scale.
xlim: Defines a tuple containing the x-axis min and max values.
ylim: Defines a tuple containing the y-axis min and max values.
title: Defines the plot title.
figsize: Width, height of figure in inches.
target_level: The level of the sources multi index to use for comparing correctness.
@@ -302,11 +302,11 @@ def plot_spectra(ss: SampleSet, in_energy: bool = False,
in_energy: Determines whether or not to try and use each spectrum's
energy calibration to interpret bins in terms of energy.
figsize: Width, height of figure in inches.
xscale: Defines the X-axis scale.
yscale: Defines the Y-axis scale.
xlim: Defines a tuple containing the X-axis min and max values.
ylim: Defines a tuple containing the Y-axis min and max values.
ylabel: Defines the Y-axis label.
xscale: Defines the x-axis scale.
yscale: Defines the y-axis scale.
xlim: Defines a tuple containing the x-axis min and max values.
ylim: Defines a tuple containing the y-axis min and max values.
ylabel: Defines the y-axis label.
title: Defines the plot title.
legend_loc: Defines the location in which to place the legend. Defaults to None.
target_level: The level of the sources multi index to use for legend labels.
@@ -381,10 +381,10 @@ def plot_learning_curve(train_loss: list, validation_loss: list,
Args:
train_loss: Defines a list of training loss values.
validation_loss: Defines a list of validation loss values.
xscale: Defines the X-axis scale.
yscale: Defines the Y-axis scale.
xlim: Defines a tuple containing the X-axis min and max values.
ylim: Defines a tuple containing the Y-axis min and max values.
xscale: Defines the x-axis scale.
yscale: Defines the y-axis scale.
xlim: Defines a tuple containing the x-axis min and max values.
ylim: Defines a tuple containing the y-axis min and max values.
smooth: Determines whether or not to apply smoothing to the loss curves.
title: Defines the plot title.
figsize: Width, height of figure in inches.
@@ -467,7 +467,7 @@ def plot_count_rate_history(cr_history: list, sample_interval: float,
pre_event_duration: Defines the time in seconds at which the anomalous source appears
(i.e., the start of the event).
validation_loss: Defines a list of validation loss values.
ylim: Defines a tuple containing the Y-axis min and max values.
ylim: Defines a tuple containing the y-axis min and max values.
title: Defines the plot title.
figsize: Width, height of figure in inches.
Expand Down Expand Up @@ -505,14 +505,21 @@ def plot_count_rate_history(cr_history: list, sample_interval: float,


@save_or_show_plot
def plot_score_histogram(ss: SampleSet, yscale="log", ylim=(1e-1, None),
title="Score Distribution", figsize=(6.4, 4.8)):
def plot_score_distribution(ss: SampleSet, bin_width=None, n_bins=100,
xscale="linear", min_bin=0.0, max_bin=1.0,
yscale="log", ylim=(1e-1, None),
title="Score Distribution", figsize=(6.4, 4.8)):
"""Plots a histogram of all of the model prediction scores.
Args:
ss: SampleSet containing prediction_probas values.
yscale: the Y-axis scale.
ylim: a tuple containing the Y-axis min and max values.
bin_width: width of each bin.
n_bins: number of bins into which to bin scores.
xscale: the x-axis scale.
min_bin: min value of the bin range; also sets x-axis min.
max_bin: max value of the bin range; also sets x-axis max.
yscale: the y-axis scale.
ylim: a tuple containing the y-axis min and max values.
title: the plot title.
figsize: Width, height of figure in inches.
@@ -522,72 +529,127 @@ def plot_score_histogram(ss: SampleSet, yscale="log", ylim=(1e-1, None),
"""
fig, ax = plt.subplots(figsize=figsize)

indices1 = ss.info.index[ss.info.snr <= 5]
values1 = ss.prediction_probas.loc[indices1].values.flatten()
values1 = np.where(values1 > 0.0, values1, values1)
indices2 = ss.info.index[(ss.info.snr > 5) &
(ss.info.snr <= 50)]
values2 = ss.prediction_probas.loc[indices2].values.flatten()
values2 = np.where(values2 > 0.0, values2, values2)
indices3 = ss.info.index[ss.info.snr > 50]
values3 = ss.prediction_probas.loc[indices3].values.flatten()
values3 = np.where(values3 > 0.0, values3, values3)

width = 0.35
bins = np.linspace(0.0, 1.0, 100)
ax.bar(
values3,
bins,
width,
color="green"
)
ax.bar(
values2,
bins,
width,
color="yellow"
)
ax.bar(
values1,
bins,
width,
color="red"
)
scores = ss.prediction_probas.values.flatten()

BINS = np.linspace(min_bin, max_bin, n_bins)
ax.hist(scores, bins=BINS, rwidth=bin_width)

ax.set_xscale(xscale)
ax.set_xlim((min_bin, max_bin))
ax.set_yscale(yscale)
ax.set_ylim(ylim)
ax.set_xlabel("Scores")
ax.set_ylabel("Occurrences")
ax.set_title(title)
fig.tight_layout()

return fig, ax
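
A usage sketch for the renamed plot_score_distribution (formerly plot_score_histogram), assuming test_ss is a SampleSet whose prediction_probas have been filled in by a model's predict call (hypothetical name):

    from riid.visualize import plot_score_distribution

    # Histogram of every prediction score across all samples and labels.
    fig, ax = plot_score_distribution(test_ss, n_bins=50, yscale="log")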


def _bin_df_values_and_plot(data: pd.Series, fig, ax):
binned_labels = data.value_counts()
binned_labels.sort_index(inplace=True)
binned_labels.plot(kind='bar', subplots=True, fig=fig, ax=ax)


@save_or_show_plot
def plot_source_distribution(ss: SampleSet, figsize=(12.8, 7.2),
target_level: str = "Isotope"):
"""Plots a bar plot of number of sources for each label.
def plot_label_distribution(ss: SampleSet, ylim: tuple = (1, None),
yscale: str = "log", figsize: tuple = (12.8, 7.2),
title: str = "Label Distribution",
target_level: str = "Isotope"):
"""Plots a histogram of the number of occurrences of each label.
Args:
ss: a SampleSet
figsize: Width, height of figure in inches.
target_level: The level of the sources multi index to use for legend labels.
ss: a SampleSet with prediction information filled in.
ylim: tuple containing the y-axis min and max values.
yscale: scale of y-axis.
figsize: width, height of figure in inches.
target_level: level of the multi index to use for x-axis labels.
Returns:
A tuple (Figure, Axes) of matplotlib objects.
Raises:
ValueError: raised if `target_level` is not "Isotope" or "Category"
"""
fig, ax = plt.subplots(figsize=figsize)

labels = ss.get_labels(target_level=target_level)
_bin_df_values_and_plot(labels, fig, ax)

ax.set_ylim(ylim)
ax.set_yscale(yscale)
ax.set_title(title)

return fig, ax


@save_or_show_plot
def plot_prediction_distribution(ss: SampleSet, ylim: tuple = (1, None),
yscale: str = "log", figsize: tuple = (12.8, 7.2),
title: str = "Prediction Distribution",
target_level: str = "Isotope"):
"""Plots a histogram of the number of occurrences of each prediction.
Args:
ss: a SampleSet with prediction information filled in.
ylim: tuple containing the y-axis min and max values.
yscale: scale of y-axis.
figsize: width, height of figure in inches.
target_level: level of the multi index to use for x-axis labels.
Returns:
A tuple (Figure, Axes) of matplotlib objects.
"""
if target_level != "Isotope" and target_level != "Category":
raise ValueError(f"Target level of '{target_level}' not supported.")
fig, ax = plt.subplots(figsize=figsize)

level_values = ss.sources.columns.get_level_values(target_level)
labels = ss.get_predictions(target_level=target_level)
_bin_df_values_and_plot(labels, fig, ax)

ax.set_ylim(ylim)
ax.set_yscale(yscale)
ax.set_title(title)

return fig, ax
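
The two new single-distribution plots share the same interface; a brief sketch, again assuming a hypothetical test_ss with both ground truth and predictions filled in:

    from riid.visualize import plot_label_distribution, plot_prediction_distribution

    # Bar chart of how often each isotope appears as a ground-truth label...
    fig, ax = plot_label_distribution(test_ss, target_level="Isotope")
    # ...and how often it appears as a predicted label.
    fig, ax = plot_prediction_distribution(test_ss, target_level="Isotope")
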


@save_or_show_plot
def plot_label_and_prediction_distributions(ss: SampleSet, ylim: tuple = (1, None),
yscale: str = "log", figsize: tuple = (12.8, 7.2),
title: str = "Label and Prediction Distribution",
target_level: str = "Isotope"):
"""Plots a histogram of the number of occurrences of each label and prediction.
Args:
ss: a SampleSet with label and prediction information filled in.
ylim: tuple containing the y-axis min and max values.
yscale: scale of y-axis.
figsize: width, height of figure in inches.
target_level: level of the multi index to use for x-axis labels.
Returns:
A tuple (Figure, Axes) of matplotlib objects.
"""
fig, ax = plt.subplots(figsize=figsize)
ax.hist(level_values.values, bins=len(level_values))
ax.set_title("Number of Sources per Label")
ax.set_xticklabels(set(level_values), rotation=45)

labels = ss.get_labels(target_level=target_level)
binned_labels = labels.value_counts()
predictions = ss.get_predictions(target_level=target_level)
binned_predictions = predictions.value_counts()

binned_labels_and_predictions = pd.DataFrame(
[binned_labels, binned_predictions],
index=["Labels", "Predictions"]).T.fillna(0.0)
binned_labels_and_predictions.sort_index(inplace=True)

binned_labels.plot(kind='bar', subplots=True, fig=fig, ax=ax)

ax.set_ylim(ylim)
ax.set_yscale(yscale)
ax.set_title(title)
ax.set_xlabel(target_level)
ax.set_ylabel("Occurrences")
fig.tight_layout()

return fig, ax
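
And the combined view, which bins label and prediction counts into one DataFrame so the two distributions can be compared on a single set of axes (same hypothetical test_ss):

    from riid.visualize import plot_label_and_prediction_distributions

    # One figure showing occurrence counts for labels and predictions together.
    fig, ax = plot_label_and_prediction_distributions(test_ss, target_level="Isotope")
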

@@ -627,36 +689,6 @@ def plot_correlation_between_all_labels(ss: SampleSet, mean: bool = False,
return fig, ax


@save_or_show_plot
def plot_correlation_between_two_labels(ss: SampleSet, label1: str, label2: str,
figsize=(6.4, 4.8), target_level: str = "Isotope"):
"""Plots a correlation matrix of each source against each other source for a pair of labels.
Args:
ss: a SampleSet.
label1: the name of the first label to compare.
label2: the name of the second label to compare.
figsize: Width, height of figure in inches.
target_level: The level of the sources multi index to use for legend labels.
Returns:
A tuple (Figure, Axes) of matplotlib objects.
"""
labels = ss.get_labels(target_level=target_level, minimum_contribution=1)
spectra1 = ss[labels == label1].spectra
spectra2 = ss[labels == label2].spectra
X = spectra1.dot(spectra2.T)

fig, ax = plt.subplots(figsize=figsize)
ax = heatmap(X, annot=False)
ax.set_title(f"Correlation '{label1}' vs. '{label2}'")
ax.set_xlabel(label2)
ax.set_ylabel(label1)

return fig, ax


@save_or_show_plot
def plot_precision_recall(precision, recall, marker="D", lw=2, show_legend=True, fig_ax=None,
title="Precision VS Recall", cmap="gist_ncar",
@@ -681,7 +713,7 @@ def plot_precision_recall(precision, recall, marker="D", lw=2, show_legend=True,
figsize: Width, height of figure in inches.
"""
from riid.models.metrics import harmonic_mean, average_precision_score
from riid.models.metrics import average_precision_score, harmonic_mean

fig, ax = fig_ax if fig_ax else plt.subplots(figsize=figsize)

2 changes: 1 addition & 1 deletion tests/anomaly_tests.py
@@ -35,7 +35,7 @@ def test_event_detector(self):
snr_function_args=(20, 20),
return_gross=True,
rng=np.random.default_rng(42))\
.generate(fg_seeds_ss, mixed_bg_seeds_ss)
.generate(fg_seeds_ss, mixed_bg_seeds_ss, verbose=False)

_, _, gross_events = list(zip(*events))
passby_ss = gross_events[0]
