
Commit

Added OOD abstract classifiers
- added an abstract classifier for OOD samples that works on top of a pretrained neural softmax classifier
- added two concrete implementations: DDU and Mahalanobis
- added a runnable example
Alberto Gasparin committed Aug 16, 2023
1 parent e3c1a3f commit 527d910
Showing 8 changed files with 582 additions and 24 deletions.
12 changes: 12 additions & 0 deletions docs/source/references/ood_classifier.rst
@@ -0,0 +1,12 @@
.. _ood_detection:

Out-Of-Distribution (OOD) detection
===================================
Starting from a trained neural softmax classifier, it is possible to fit one of the models below
to help distinguish between in-distribution and out-of-distribution inputs.

All the classes below are abstract: to use them, the ``apply`` method has to be implemented.

.. autoclass:: fortuna.ood_detection.mahalanobis.MalahanobisClassifierABC

.. autoclass:: fortuna.ood_detection.ddu.DeepDeterministicUncertaintyABC
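
A minimal sketch of a concrete subclass (the subclass name and the parameter layout are
illustrative placeholders, not Fortuna's actual ones; see the runnable two-moons OOD example
for a complete version):

.. code-block:: python

    from fortuna.ood_detection.mahalanobis import MalahanobisClassifierABC


    class MyMalahanobisClassifier(MalahanobisClassifierABC):
        def apply(self, inputs, params, mutable, **kwargs):
            # Map inputs to embeddings with the feature-extractor sub-network.
            variables = {"params": params}
            if mutable is not None:
                variables.update(mutable)
            return self.feature_extractor_subnet.apply(
                variables, inputs, train=False, mutable=False
            )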
1 change: 1 addition & 0 deletions docs/source/references/references.rst
@@ -12,6 +12,7 @@ API References
output_calibrator
prob_output_layer
conformal
ood_detection
data_loader
metric
utils
2 changes: 1 addition & 1 deletion examples/index.rst
@@ -16,8 +16,8 @@ In this section we show some examples of how to use Fortuna in classification and
multivalid_coverage
sinusoidal_regression
two_moons_classification
two_moons_classification_ood
subnet_calibration
two_moons_classification_sngp
scaling_up_bayesian_inference
mnist_classification_sghmc
sgmcmc_diagnostics
@@ -14,18 +14,20 @@

# %% [markdown]

# # Two-moons Classification: Improved uncertainty quantification

# %% [markdown]
# In this notebook we will see how to fix model overconfidence on inputs that are far away from the training data.
# We will do that using two different approaches; let's dive right into it!


# %% [markdown]
# ### Download the Two-Moons data from scikit-learn
# ### Setup
# #### Download the Two-Moons data from scikit-learn
# Let us first download the two-moons data from [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html).

# %%
from matplotlib import colors

TRAIN_DATA_SIZE = 500

@@ -36,7 +38,7 @@
test_data = make_moons(n_samples=500, noise=0.1, random_state=2)

# %% [markdown]
# #### Convert data to a compatible data loader
# Fortuna helps you convert data and data loaders into a data loader that Fortuna can digest.

# %%
@@ -49,35 +51,49 @@
test_data_loader = DataLoader.from_array_data(test_data, batch_size=256, prefetch=True)

# %% [markdown]
# #### Define some utils for plotting the estimated uncertainty

# %%
import matplotlib.pyplot as plt
import numpy as np
from fortuna.data import InputsLoader
from fortuna.prob_model import ProbClassifier
import jax.numpy as jnp


def get_grid_inputs_loader(grid_size: int = 100):
    xx = np.linspace(-4, 4, grid_size)
    yy = np.linspace(-4, 4, grid_size)
    grid = np.array([[_xx, _yy] for _xx in xx for _yy in yy])
    grid_inputs_loader = InputsLoader.from_array_inputs(grid)
    grid = grid.reshape(grid_size, grid_size, 2)
    return grid, grid_inputs_loader


def compute_test_modes(
    prob_model: ProbClassifier, test_data_loader: DataLoader
):
    test_inputs_loader = test_data_loader.to_inputs_loader()
    test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
    return prob_model.predictive.mode(
        inputs_loader=test_inputs_loader, means=test_means
    )


def plot_uncertainty_over_grid(
    grid: jnp.ndarray, scores: jnp.ndarray, test_modes: jnp.ndarray, title: str = "Predictive uncertainty"
):
    scores = scores.reshape(grid.shape[0], grid.shape[1])

    _, ax = plt.subplots(figsize=(7, 5.5))
    plt.title(title, fontsize=12)
    pcm = ax.imshow(
        scores.T,
        origin="lower",
        extent=(-4., 4., -4., 4.),
        interpolation='bicubic',
        aspect='auto')

    # Plot the test data.
    plt.scatter(
        test_data[0][:, 0],
        test_data[0][:, 1],
@@ -126,12 +142,83 @@ def plot_uncertainty(
)

# %%
test_modes = compute_test_modes(prob_model, test_data_loader)
grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100)
grid_entropies = prob_model.predictive.entropy(grid_inputs_loader)
plot_uncertainty_over_grid(grid=grid, scores=grid_entropies, test_modes=test_modes)
plt.show()

# %% [markdown]
# Clearly, the model is overconfident on inputs that are far away from the training data.
# This behaviour is not what one would expect: we would rather the model be less confident on out-of-distribution inputs.

# %% [markdown]
# ### Fit an OOD classifier to distinguish between in-distribution and out-of-distribution inputs
# Given the trained model from above, we can now use one of the models provided by Fortuna to improve
# the model's confidence on out-of-distribution inputs.
# In the example below we will use the Mahalanobis-based classifier introduced in
# [Lee, Kimin, et al.](https://proceedings.neurips.cc/paper/2018/file/abdeb6f575ac5c6676b747bca8d09cc2-Paper.pdf).
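
# %% [markdown]
# To make the method concrete: the classifier fits one Gaussian per class to the training embeddings,
# with a covariance matrix shared across classes, and scores a new input by its minimum Mahalanobis
# distance to the class centroids. The sketch below illustrates this computation on raw embedding
# arrays; it is only a sketch of the underlying idea, not Fortuna's implementation.

# %%
import numpy as np


def fit_class_gaussians(embeddings, labels, num_classes):
    # Per-class means and a single (tied) covariance matrix, as in Lee et al.
    means = np.stack(
        [embeddings[labels == c].mean(axis=0) for c in range(num_classes)]
    )
    centered = embeddings - means[labels]
    shared_cov = centered.T @ centered / len(embeddings)
    return means, np.linalg.inv(shared_cov)


def mahalanobis_score(x_emb, means, inv_cov):
    # Smaller distance to the closest class centroid means more in-distribution.
    diffs = x_emb[None, :] - means
    dists = np.einsum("ci,ij,cj->c", diffs, inv_cov, diffs)
    return dists.min()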

# %%
from fortuna.ood_detection.mahalanobis import MalahanobisClassifierABC
from fortuna.model.mlp import DeepResidualFeatureExtractorSubNet
from functools import partial
import jax


class MalahanobisClassifier(MalahanobisClassifierABC):
    @partial(jax.jit, static_argnums=(0,))
    def apply(self, inputs, params, mutable, **kwargs):
        variables = {'params': params["model"]['params']['dfe_subnet'].unfreeze()}
        if mutable is not None:
            mutable_variables = {k: v['dfe_subnet'].unfreeze() for k, v in mutable["model"].items()}
            variables.update(mutable_variables)
        return self.feature_extractor_subnet.apply(
            variables, inputs, train=False, mutable=False,
        )


ood_classifier = MalahanobisClassifier(
    feature_extractor_subnet=DeepResidualFeatureExtractorSubNet(
        dense=model.dense,
        widths=model.widths,
        activations=model.activations,
        dropout=model.dropout,
        dropout_rate=model.dropout_rate,
    )
)

# %% [markdown]
# In the code block above we first define a `MalahanobisClassifier` starting from the `MalahanobisClassifierABC`
# provided by Fortuna. The only thing we need to do here is to implement the `apply` method, which transforms
# an input vector into an embedding vector.
# Once this is done, we can initialize our classifier by providing the `feature_extractor_subnet`. In the example,
# this is our original model (`DeepResidualNet`) without the output layer.
# We are now ready to fit the classifier using our training data and verify whether the model's overconfidence has been
# (at least partially) fixed:

# %%
state = prob_model.posterior.state.get()
ood_classifier.fit(state=state, train_data_loader=train_data_loader, num_classes=2)
grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100)
grid_scores = ood_classifier.score(state=state, inputs_loader=grid_inputs_loader)
# for the sake of plotting we set a threshold on the OOD classifier scores using the max score
# obtained from a known in-distribution source
ind_scores = ood_classifier.score(state=state, inputs_loader=val_data_loader.to_inputs_loader())
threshold = ind_scores.max()*2
grid_scores = jnp.where(grid_scores < threshold, grid_scores, threshold)
plot_uncertainty_over_grid(grid=grid, scores=grid_scores, test_modes=test_modes, title="OOD scores")
plt.show()


# %% [markdown]
# We will now see a different way of obtaining improved uncertainty estimation
# (for out-of-distribution inputs): [SNGP](https://arxiv.org/abs/2006.10108).
# Unlike before, we now have to retrain the model, as the architecture will slightly change.
# The reason for this will be clear from the model definition below.

# %% [markdown]
# ### Define the SNGP model
# Compared to the deterministic model obtained in the first part of this notebook, SNGP has two crucial differences:
#
# 1. [Spectral Normalization](https://arxiv.org/abs/1802.05957) is applied to all Dense (or Convolutional) layers (see the sketch after this list).
# 2. The Dense output layer is replaced with a Gaussian Process layer.
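
# %% [markdown]
# To make point 1 concrete: spectral normalization rescales each weight matrix so that its largest
# singular value stays below a fixed bound, which keeps the feature extractor approximately
# distance preserving. Below is a minimal sketch of the usual power-iteration estimate on a plain
# weight matrix; it is illustrative only and not the layer Fortuna applies internally.

# %%
import jax.numpy as jnp


def estimate_spectral_norm(w, num_iters=20):
    # Power iteration: estimate the largest singular value of w.
    v = jnp.ones((w.shape[1],)) / jnp.sqrt(w.shape[1])
    for _ in range(num_iters):
        u = w @ v
        u = u / jnp.linalg.norm(u)
        v = w.T @ u
        v = v / jnp.linalg.norm(v)
    return jnp.dot(u, w @ v)


def spectrally_normalize(w, bound=1.0, num_iters=20):
    # Rescale w so that its spectral norm does not exceed `bound`.
    sigma = estimate_spectral_norm(w, num_iters)
    return jnp.where(sigma > bound, w * (bound / sigma), w)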
@@ -207,7 +294,10 @@ class SNGPDeepFeatureExtractorSubNet(
)

# %%
test_modes = compute_test_modes(prob_model, test_data_loader)
grid, grid_inputs_loader = get_grid_inputs_loader(grid_size=100)
grid_entropies = prob_model.predictive.entropy(grid_inputs_loader)
plot_uncertainty_over_grid(grid=grid, scores=grid_entropies, test_modes=test_modes)
plt.show()

# %% [markdown]
Empty file.
68 changes: 68 additions & 0 deletions fortuna/ood_detection/base.py
@@ -0,0 +1,68 @@
import abc
from functools import partial
from typing import (
    Tuple,
    Union,
)

from flax import linen as nn
from flax.training.checkpoints import PyTree
import jax
from jax import numpy as jnp

from fortuna.data.loader.base import (
    BaseDataLoaderABC,
    BaseInputsLoader,
)
from fortuna.prob_model.posterior.state import PosteriorState
from fortuna.typing import InputData, Params, Mutable


class NotFittedError(ValueError, AttributeError):
    """Exception class to raise if estimator is used before fitting."""


class OutOfDistributionClassifierABC:
    """
    Post-training classifier that uses the training sample embeddings coming from the model
    to score a (new) test sample w.r.t. its chance of belonging to the original training distribution
    (i.e., it is in-distribution) or not (i.e., it is out of distribution).
    """

    def __init__(self, feature_extractor_subnet: nn.Module):
        """
        Parameters
        ----------
        feature_extractor_subnet: nn.Module
            The model (or a part of it) used to obtain the embeddings of any given input.
        """
        self.feature_extractor_subnet = feature_extractor_subnet

    @abc.abstractmethod
    def apply(
        self,
        inputs: InputData,
        params: Params,
        mutable: Mutable,
        **kwargs,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]:
        """
        Transform an input :math:`\mathbf{x}` into an embedding :math:`f(\mathbf{x})`.
        """
        pass
        # return self.feature_extractor_subnet(**inputs, train=False)[1]

    @abc.abstractmethod
    def fit(
        self,
        state: PosteriorState,
        train_data_loader: BaseDataLoaderABC,
        num_classes: int,
    ) -> None:
        pass

    @abc.abstractmethod
    def score(
        self, state: PosteriorState, inputs_loader: BaseInputsLoader
    ) -> jnp.ndarray:
        pass

