Merge pull request #87 from deel-ai/feat/simp_hopfield_energy
SHE (Simplified Hopfield Energy) OOD detector
y-prudent authored Apr 23, 2024
2 parents e3784c3 + 003d410 commit 61da64c
Showing 17 changed files with 1,534 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -178,6 +178,7 @@ Currently, **oodeel** includes the following baselines:
| Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | avail [tensorflow](docs/notebooks/tensorflow/demo_gram_tf.ipynb) or [torch](docs/notebooks/torch/demo_gram_torch.ipynb) |
| GEN | [GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.html) | CVPR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_gen_tf.ipynb) or [torch](docs/notebooks/torch/demo_gen_torch.ipynb) |
| RMDS | [A Simple Fix to Mahalanobis Distance for Improving Near-OOD Detection](https://arxiv.org/abs/2106.09022) | preprint | avail [tensorflow](docs/notebooks/tensorflow/demo_rmds_tf.ipynb) or [torch](docs/notebooks/torch/demo_rmds_torch.ipynb) |
| SHE | [Out-of-Distribution Detection based on In-Distribution Data Patterns Memorization with Modern Hopfield Energy](https://openreview.net/forum?id=KkazG4lgKL) | ICLR 2023 | avail [tensorflow](docs/notebooks/tensorflow/demo_she_tf.ipynb) or [torch](docs/notebooks/torch/demo_she_torch.ipynb) |
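
The new detector follows the common oodeel API. A minimal usage sketch, mirroring the test added in this PR — `model`, `fit_dataset`, and `ood_dataset` are placeholders, and `feature_layers_id=[-2]` assumes the penultimate layer as in the original paper:

```python
from oodeel.methods import SHE

she = SHE()  # eps=0.0014: magnitude of the gradient-based input perturbation

# fit on ID data; class-conditional feature means are computed here
she.fit(model, fit_dataset, feature_layers_id=[-2])

# scores are negated simplified Hopfield energies: higher = more OOD
scores, info = she.score(ood_dataset)
```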



1 change: 1 addition & 0 deletions docs/index.md
@@ -176,6 +176,7 @@ Currently, **oodeel** includes the following baselines:
| NMD | [Neural Mean Discrepancy for Efficient Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2022/html/Dong_Neural_Mean_Discrepancy_for_Efficient_Out-of-Distribution_Detection_CVPR_2022_paper.html) | CVPR 2022 | planned |
| Gram | [Detecting Out-of-Distribution Examples with Gram Matrices](https://proceedings.mlr.press/v119/sastry20a.html) | ICML 2020 | avail [tensorflow](./notebooks/tensorflow/demo_gram_tf.ipynb) or [torch](./notebooks/torch/demo_gram_torch.ipynb) |
| GEN | [GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection](https://openaccess.thecvf.com/content/CVPR2023/html/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.html) | CVPR 2023 | avail [tensorflow](./notebooks/tensorflow/demo_gen_tf.ipynb) or [torch](./notebooks/torch/demo_gen_torch.ipynb) |
| SHE | [Out-of-Distribution Detection based on In-Distribution Data Patterns Memorization with Modern Hopfield Energy](https://openreview.net/forum?id=KkazG4lgKL) | ICLR 2023 | avail [tensorflow](./notebooks/tensorflow/demo_she_tf.ipynb) or [torch](./notebooks/torch/demo_she_torch.ipynb) |


**Oodeel** also includes standard training functions with data augmentation and a learning rate scheduler for toy convnet models or models from `keras.applications` in [tf_training_tools.py](https://github.com/deel-ai/oodeel/tree/master/oodeel/utils/tf_training_tools.py) and `torchvision.models` in [torch_training_tools.py](https://github.com/deel-ai/oodeel/tree/master/oodeel/utils/torch_training_tools.py). These functions come in handy for benchmarks like *leave-k-classes-out* that require retraining models on a subset of dataset classes.
592 changes: 592 additions & 0 deletions docs/notebooks/tensorflow/demo_she_tf.ipynb


579 changes: 579 additions & 0 deletions docs/notebooks/torch/demo_she_torch.ipynb


2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -18,6 +18,7 @@ nav:
- Gram: notebooks/tensorflow/demo_gram_tf.ipynb
- GEN: notebooks/tensorflow/demo_gen_tf.ipynb
- RMDS: notebooks/tensorflow/demo_rmds_tf.ipynb
- SHE: notebooks/tensorflow/demo_she_tf.ipynb
- OOD Baselines (Torch):
- MLS/MSP: notebooks/torch/demo_mls_msp_torch.ipynb
- ODIN: notebooks/torch/demo_odin_torch.ipynb
@@ -30,6 +31,7 @@ nav:
- Gram: notebooks/torch/demo_gram_torch.ipynb
- GEN: notebooks/torch/demo_gen_torch.ipynb
- RMDS: notebooks/torch/demo_rmds_torch.ipynb
- SHE: notebooks/torch/demo_she_torch.ipynb
- Advanced Topics:
- Seamlessly handling torch and tf datasets with DataHandler: pages/datahandler_tuto.md
- Seamlessly handling torch and tf Tensors with Operator: pages/operator_tuto.md
2 changes: 2 additions & 0 deletions oodeel/methods/__init__.py
@@ -29,13 +29,15 @@
from .mls import MLS
from .odin import ODIN
from .rmds import RMDS
from .she import SHE
from .vim import VIM

__all__ = [
"DKNN",
"Energy",
"Entropy",
"GEN",
"SHE",
"Gram",
"Mahalanobis",
"MLS",
2 changes: 1 addition & 1 deletion oodeel/methods/gram.py
@@ -304,4 +304,4 @@ def requires_internal_features(self) -> bool:
bool: True if the detector performs computations on an intermediate layer
else False.
"""
return False
return True
197 changes: 197 additions & 0 deletions oodeel/methods/she.py
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import numpy as np

from ..types import DatasetType
from ..types import TensorType
from ..types import Union
from .base import OODBaseDetector


class SHE(OODBaseDetector):
"""
"Out-of-Distribution Detection based on In-Distribution Data Patterns Memorization
with Modern Hopfield Energy"
[link](https://openreview.net/forum?id=KkazG4lgKL)
This method first computes the mean of the internal layer representation of ID data
for each ID class. This mean is seen as the average of the ID activation patterns
as defined in the original paper.
The method then returns the maximum value of the dot product between the internal
layer representation of the input and the average patterns, which is a simplified
version of Hopfield energy as defined in the original paper.
Remarks:
* An input perturbation is applied in the same way as in the Mahalanobis score
* The original paper only considers the penultimate layer of the neural
network, while we aggregate the results of multiple layers after normalizing by
the dimension of each vector (the activation vector for dense layers, and the
average pooling of the feature map for convolutional layers).
Args:
eps (float): magnitude for gradient based input perturbation.
Defaults to 0.0014.
"""

def __init__(
self,
eps: float = 0.0014,
):
super().__init__()
self.eps = eps
self.postproc_fns = None

def _postproc_feature_maps(self, feature_map):
if len(feature_map.shape) > 2:
feature_map = self.op.avg_pool_2d(feature_map)
return self.op.flatten(feature_map)

def _fit_to_dataset(
self,
fit_dataset: Union[TensorType, DatasetType],
) -> None:
"""
Compute the means of the input dataset in the activation space of the selected
layers. The means are computed for each class in the dataset.
Args:
fit_dataset (Union[TensorType, DatasetType]): input dataset (ID) from which
the class-conditional means are computed.
"""
self.postproc_fns = [
self._postproc_feature_maps
for i in range(len(self.feature_extractor.feature_layers_id))
]

features, infos = self.feature_extractor.predict(
fit_dataset, postproc_fns=self.postproc_fns
)

labels = infos["labels"]
preds = self.op.argmax(infos["logits"], dim=-1)
preds = self.op.convert_to_numpy(preds)

# unique sorted classes
self._classes = np.sort(np.unique(self.op.convert_to_numpy(labels)))
labels = self.op.convert_to_numpy(labels)

self._mus = list()
for feature in features:
mus_f = list()
for cls in self._classes:
indexes = np.equal(labels, cls) & np.equal(preds, cls)
_features_cls = feature[indexes]
mus_f.append(
self.op.unsqueeze(self.op.mean(_features_cls, dim=0), dim=0)
)
self._mus.append(self.op.permute(self.op.cat(mus_f), (1, 0)))

def _score_tensor(self, inputs: TensorType) -> np.ndarray:
"""
Computes an OOD score for input samples "inputs" based on the simplified
Hopfield energy aggregated over the selected layers.
Args:
inputs: input samples to score
Returns:
scores
"""

inputs_p = self._input_perturbation(inputs)
features, _ = self.feature_extractor.predict_tensor(
inputs_p, postproc_fns=self.postproc_fns
)

scores = self._get_she_output(features)

return -self.op.convert_to_numpy(scores)

def _get_she_output(self, features):
scores = None
for feature, mus_f in zip(features, self._mus):
she = self.op.matmul(self.op.squeeze(feature), mus_f) / feature.shape[1]
she = self.op.max(she, dim=1)
scores = she if scores is None else she + scores
return scores

def _input_perturbation(self, inputs: TensorType) -> TensorType:
"""
Apply small perturbation on inputs to make the in- and out- distribution
samples more separable.
Args:
inputs (TensorType): input samples
Returns:
TensorType: Perturbed inputs
"""

def __loss_fn(inputs: TensorType) -> TensorType:
"""
Loss function for the input perturbation.
Args:
inputs (TensorType): input samples
Returns:
TensorType: loss value
"""
# extract features
out_features, _ = self.feature_extractor.predict(
inputs, detach=False, postproc_fns=self.postproc_fns
)
# get SHE score for the class maximizing it
she_score = self._get_she_output(out_features)
log_probs_f = self.op.log(she_score)
return self.op.mean(log_probs_f)

# compute gradient
gradient = self.op.gradient(__loss_fn, inputs)
gradient = self.op.sign(gradient)

inputs_p = inputs - self.eps * gradient
return inputs_p

@property
def requires_to_fit_dataset(self) -> bool:
"""
Whether an OOD detector needs a `fit_dataset` argument in the fit function.
Returns:
bool: True if `fit_dataset` is required else False.
"""
return True

@property
def requires_internal_features(self) -> bool:
"""
Whether an OOD detector acts on internal model features.
Returns:
bool: True if the detector performs computations on an intermediate layer
else False.
"""
return True
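
To make the scoring rule concrete: for each selected layer, SHE stores one mean pattern per class, mu_c, computed over correctly classified ID features of class c, and scores a sample by max_c <f(x), mu_c> / d, summed over layers. Below is a small NumPy sketch of the single-layer case (names and shapes are illustrative, not part of the oodeel API):

```python
import numpy as np

def fit_class_means(features, labels, preds, classes):
    """One mean ID pattern per class, keeping only correctly classified
    samples, as in SHE._fit_to_dataset. features: (n, d); labels, preds: (n,)."""
    mus = [features[(labels == cls) & (preds == cls)].mean(axis=0) for cls in classes]
    return np.stack(mus, axis=1)  # (d, n_classes), like the permuted `cat` above

def she_energy(features, mus):
    """Simplified Hopfield energy: max over classes of <f(x), mu_c>, normalized
    by the feature dimension d, as in SHE._get_she_output. Higher = more ID-like."""
    return (features @ mus / features.shape[1]).max(axis=1)

# toy demo with random 8-d features and 3 classes
rng = np.random.default_rng(0)
feats = rng.normal(size=(100, 8))
labels = rng.integers(0, 3, size=100)
mus = fit_class_means(feats, labels, labels, np.arange(3))  # preds == labels here
print(she_energy(feats[:5], mus))  # one energy per sample; the detector returns -energy
```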
17 changes: 17 additions & 0 deletions oodeel/utils/operator.py
@@ -221,6 +221,11 @@ def unsqueeze(tensor: TensorType, dim: int) -> TensorType:
"unsqueeze/expand_dim along dim"
raise NotImplementedError()

@staticmethod
def squeeze(tensor: TensorType, dim: int = None) -> TensorType:
"squeeze along dim"
raise NotImplementedError()

@staticmethod
def abs(tensor: TensorType) -> TensorType:
"compute absolute value"
@@ -234,3 +239,15 @@ def where(
) -> TensorType:
"Applies where function to condition"
raise NotImplementedError()

@staticmethod
@abstractmethod
def avg_pool_2d(tensor: TensorType) -> TensorType:
"""Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d"""
raise NotImplementedError()

@staticmethod
@abstractmethod
def log(tensor: TensorType) -> TensorType:
"""Perform log"""
raise NotImplementedError()
15 changes: 15 additions & 0 deletions oodeel/utils/tf_operator.py
@@ -231,6 +231,11 @@ def unsqueeze(tensor: TensorType, dim: int) -> tf.Tensor:
"expand_dim along dim"
return tf.expand_dims(tensor, dim)

@staticmethod
def squeeze(tensor: TensorType, dim: int = None) -> tf.Tensor:
"expand_dim along dim"
return tf.squeeze(tensor, dim)

@staticmethod
def abs(tensor: TensorType) -> tf.Tensor:
"compute absolute value"
@@ -248,3 +253,13 @@ def where(
@staticmethod
def percentile(x, q):
return tfp.stats.percentile(x, q)

@staticmethod
def avg_pool_2d(tensor: TensorType) -> tf.Tensor:
"""Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d"""
return tf.reduce_mean(tensor, axis=(-3, -2))

@staticmethod
def log(tensor: TensorType) -> tf.Tensor:
"""Perform log"""
return tf.math.log(tensor)
19 changes: 19 additions & 0 deletions oodeel/utils/torch_operator.py
@@ -254,6 +254,15 @@ def unsqueeze(tensor: TensorType, dim: int) -> torch.Tensor:
"unsqueeze along dim"
return torch.unsqueeze(tensor, dim)

@staticmethod
def squeeze(tensor: TensorType, dim: int = None) -> torch.Tensor:
"squeeze along dim"

if dim is None:
return torch.squeeze(tensor)

return torch.squeeze(tensor, dim)

@staticmethod
def abs(tensor: TensorType) -> torch.Tensor:
"compute absolute value"
@@ -267,3 +276,13 @@ def where(
) -> torch.Tensor:
"Applies where function , to condition"
return torch.where(condition, input, other)

@staticmethod
def avg_pool_2d(tensor: TensorType) -> torch.Tensor:
"""Perform avg pool in 2d as in torch.nn.functional.adaptive_avg_pool2d"""
return torch.mean(tensor, dim=(-2, -1))

@staticmethod
def log(tensor: TensorType) -> torch.Tensor:
"""Perform log"""
return torch.log(tensor)
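
The two `avg_pool_2d` implementations reduce over different axes because TensorFlow feature maps are channels-last (NHWC) while torch's are channels-first (NCHW); both collapse the spatial dimensions to one value per channel, i.e. adaptive average pooling to output size 1. A quick framework-neutral check in NumPy (array names are illustrative):

```python
import numpy as np

# the same feature map in both layouts: batch=2, 4x4 spatial, 3 channels
fmap_nhwc = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
fmap_nchw = fmap_nhwc.transpose(0, 3, 1, 2)

pooled_tf = fmap_nhwc.mean(axis=(-3, -2))     # H, W of an NHWC tensor -> (2, 3)
pooled_torch = fmap_nchw.mean(axis=(-2, -1))  # H, W of an NCHW tensor -> (2, 3)

assert np.allclose(pooled_tf, pooled_torch)   # identical per-channel means
```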
5 changes: 1 addition & 4 deletions tests/tests_tensorflow/methods/test_tf_gram.py
@@ -27,10 +27,7 @@

def test_gram_shape():
"""
Test Mahalanobis on MNIST vs FashionMNIST OOD dataset-wise task
We check that the area under ROC is above a certain threshold, and that the FPR95TPR
is below an other threshold.
Test Gram method on MNIST vs FashionMNIST OOD dataset-wise task
"""
gram = Gram(orders=range(1, 6))

50 changes: 50 additions & 0 deletions tests/tests_tensorflow/methods/test_tf_she.py
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from oodeel.methods import SHE
from tests.tests_tensorflow import generate_data_tf
from tests.tests_tensorflow import generate_model


def test_she_shape():
"""
Test SHE on MNIST vs FashionMNIST OOD dataset-wise task
"""
she = SHE()

input_shape = (32, 32, 3)
num_labels = 10
samples = 100

data = generate_data_tf(
x_shape=input_shape, num_labels=num_labels, samples=samples
).batch(samples // 2)

model = generate_model(input_shape=input_shape, output_shape=num_labels)

she.fit(model, data, feature_layers_id=[-5, -2])
score, _ = she.score(data)
assert score.shape == (100,)

she.fit(model, data, feature_layers_id=[-2])
score, _ = she.score(data)
assert score.shape == (100,)