-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- added abstract classifier for OOD samples that works on top of a pretrained neural softmax classifier - added two concrete implementations: DDU and Mahalanobis - added runnable example
- Loading branch information
Alberto Gasparin
committed
Aug 16, 2023
1 parent
e3c1a3f
commit 527d910
Showing
8 changed files
with
582 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
.. _ood_detection: | ||
|
||
Out-Of-Distribution (OOD) detection | ||
=================================== | ||
Starting from a trained neural softmax classifier, it's possible to fit one of the models below | ||
to help distinguish between in-distribution and out of distribution inputs. | ||
|
||
All the classes below are abstract; before one can be used, its ``apply`` method has to be defined. | ||
|
||
.. autoclass:: fortuna.ood_detection.mahalanobis.MalahanobisClassifierABC | ||
|
||
.. autoclass:: fortuna.ood_detection.ddu.DeepDeterministicUncertaintyABC |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ API References | |
output_calibrator | ||
prob_output_layer | ||
conformal | ||
ood_detection | ||
data_loader | ||
metric | ||
utils | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import abc | ||
from functools import partial | ||
from typing import ( | ||
Tuple, | ||
Union, | ||
) | ||
|
||
from flax import linen as nn | ||
from flax.training.checkpoints import PyTree | ||
import jax | ||
from jax import numpy as jnp | ||
|
||
from fortuna.data.loader.base import ( | ||
BaseDataLoaderABC, | ||
BaseInputsLoader, | ||
) | ||
from fortuna.prob_model.posterior.state import PosteriorState | ||
from fortuna.typing import InputData, Params, Mutable | ||
|
||
|
||
class NotFittedError(ValueError, AttributeError):
    """Raised when an estimator is queried before ``fit`` has been called.

    Derives from both ``ValueError`` and ``AttributeError`` so that callers
    catching either of those built-in exception types will also handle it.
    """
|
||
class OutOfDistributionClassifierABC(abc.ABC):
    """
    Post-training classifier that uses the training sample embeddings coming from the model
    to score a (new) test sample w.r.t. its chance of belonging to the original training
    distribution (i.e., it is in-distribution) or not (i.e., it is out of distribution).

    This class is abstract: subclasses must implement :meth:`apply`, :meth:`fit`
    and :meth:`score` before they can be instantiated.
    """

    def __init__(self, feature_extractor_subnet: nn.Module):
        """
        Parameters
        ----------
        feature_extractor_subnet: nn.Module
            The model (or a part of it) used to obtain the embeddings of any given input.
        """
        self.feature_extractor_subnet = feature_extractor_subnet

    @abc.abstractmethod
    def apply(
        self,
        inputs: InputData,
        params: Params,
        mutable: Mutable,
        **kwargs,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]:
        # Raw docstring: ``\mathbf`` is not a valid Python escape sequence in a
        # plain string literal.
        r"""
        Transform an input :math:`\mathbf{x}` into an embedding :math:`f(\mathbf{x})`.

        Implementations typically delegate to ``self.feature_extractor_subnet``.

        Parameters
        ----------
        inputs: InputData
            The inputs to embed.
        params: Params
            The model parameters.
        mutable: Mutable
            The mutable model state, if any.

        Returns
        -------
        Union[jnp.ndarray, Tuple[jnp.ndarray, PyTree]]
            The embeddings, optionally together with an updated mutable state.
        """

    @abc.abstractmethod
    def fit(
        self,
        state: PosteriorState,
        train_data_loader: BaseDataLoaderABC,
        num_classes: int,
    ) -> None:
        """
        Fit the classifier using the embeddings of the training data.

        Parameters
        ----------
        state: PosteriorState
            The posterior state holding the trained model's parameters.
        train_data_loader: BaseDataLoaderABC
            The loader over the (in-distribution) training data.
        num_classes: int
            The number of classes of the original classification task.
        """

    @abc.abstractmethod
    def score(
        self, state: PosteriorState, inputs_loader: BaseInputsLoader
    ) -> jnp.ndarray:
        """
        Score each input w.r.t. its chance of being out of distribution.

        NOTE(review): the direction of the score (higher = more OOD vs. less OOD)
        is implementation-defined — confirm against the concrete subclasses.

        Parameters
        ----------
        state: PosteriorState
            The posterior state holding the trained model's parameters.
        inputs_loader: BaseInputsLoader
            The loader over the inputs to score.

        Returns
        -------
        jnp.ndarray
            One score per input.
        """
Oops, something went wrong.