Skip to content

Commit

Permalink
add grouping model for hallucination detection
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Nov 7, 2023
1 parent e6b8c85 commit 7e5e0a9
Show file tree
Hide file tree
Showing 7 changed files with 328 additions and 0 deletions.
2 changes: 2 additions & 0 deletions fortuna/hallucination/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from fortuna.hallucination.embedding import EmbeddingManager
from fortuna.hallucination.grouping.clustering.base import GroupingModel
35 changes: 35 additions & 0 deletions fortuna/hallucination/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Callable

import numpy as np
from tqdm import (
tqdm,
trange,
)

from fortuna.data import InputsLoader
from fortuna.typing import Array


class EmbeddingManager:
    """Compute reduced embeddings for inputs served by an inputs loader."""

    def __init__(
        self,
        encoding_fn: Callable[[Array], Array],
        reduction_fn: Callable[[Array], Array],
    ):
        """
        Parameters
        ----------
        encoding_fn: Callable[[Array], Array]
            Maps a single input to its embedding vector.
        reduction_fn: Callable[[Array], Array]
            Maps a 2-D array of stacked embeddings to a reduced-dimension array.
        """
        self.encoding_fn = encoding_fn
        self.reduction_fn = reduction_fn

    def get(self, inputs_loader: InputsLoader) -> Array:
        """
        Encode every input in the loader, stack the embeddings, and apply the
        reduction function.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.

        Returns
        -------
        Array
            The reduced embeddings, one row per input.
        """
        embeddings = []
        for inputs in tqdm(inputs_loader, desc="Batch"):
            # Encode each input individually and stack into a (batch, dim) array.
            # `np.asarray` avoids the original `.tolist()` detour through Python
            # lists, which `np.vstack` would immediately convert back to arrays.
            embeddings.append(
                np.vstack(
                    [
                        np.asarray(self.encoding_fn(inputs[i]))
                        for i in trange(len(inputs), desc="Encode")
                    ]
                )
            )
        embeddings = np.concatenate(embeddings, axis=0)
        return self.reduction_fn(embeddings)
Empty file.
Empty file.
178 changes: 178 additions & 0 deletions fortuna/hallucination/grouping/clustering/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from typing import (
List,
Optional,
)

import numpy as np

from fortuna.data import InputsLoader
from fortuna.hallucination.embedding import EmbeddingManager
from fortuna.typing import Array


class GroupingModel:
    """
    Grouping model based on clustering of embeddings.
    """

    def __init__(
        self, embedding_manager: EmbeddingManager, quantile_proba_threshold: float = 0.8
    ):
        """
        Parameters
        ----------
        embedding_manager: EmbeddingManager
            Produces embeddings for the inputs to cluster.
        quantile_proba_threshold: float
            Quantile level used to derive per-cluster probability thresholds
            for `soft_predict`. Defaults to 0.8.
        """
        self.embedding_manager = embedding_manager
        self._clustering_model = None
        self._embeddings_mean = None
        self._embeddings_std = None
        self._quantiles = None
        self.quantile_proba_threshold = quantile_proba_threshold

    def fit(
        self,
        inputs_loader: InputsLoader,
        clustering_models: List,
        extra_embeddings: Optional[Array] = None,
    ) -> None:
        """
        Fit the model.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.
        clustering_models: List
            A list of clustering models. Each clustering model must include the
            following methods:

            - `fit`, to fit the model;
            - `bic`, to compute the BIC;
            - `predict_proba`, to get the predicted probabilities;
            - `predict`, to get the predictions.

            An example of valid clustering model is
            `sklearn.mixture.GaussianMixture`.
        extra_embeddings: Optional[Array]
            An extra array of embeddings.
        """
        if not isinstance(clustering_models, list) or not len(clustering_models):
            raise ValueError("`clustering_models` must be a non-empty list.")
        embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings)
        self._store_embeddings_stats(embeddings)
        embeddings = self._normalize(embeddings)

        # Select the candidate model with the lowest BIC.
        best_bic = np.inf
        best_model = None
        for clustering_model in clustering_models:
            model = clustering_model.fit(embeddings)
            bic = clustering_model.bic(embeddings)
            if bic < best_bic:
                best_bic = bic
                best_model = model
        self._clustering_model = best_model

        probs = self._clustering_model.predict_proba(embeddings)
        self._store_thresholds(
            probs=probs, quantile_threshold=self.quantile_proba_threshold
        )

    def predict_proba(
        self, inputs_loader: InputsLoader, extra_embeddings: Optional[Array] = None
    ) -> Array:
        """
        For each input, predict the probability of belonging to each cluster.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            Predicted probabilities.
        """
        if self._clustering_model is None:
            raise ValueError("The `fit` method must be run first.")
        embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings)
        embeddings = self._normalize(embeddings)
        return self._clustering_model.predict_proba(embeddings)

    def soft_predict(
        self,
        inputs_loader: InputsLoader,
        extra_embeddings: Optional[Array] = None,
    ) -> Array:
        """
        For each input, predict which clusters the inputs are most likely to belong to.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            An array of bools determining whether an input is predicted to belong to a cluster or not.
        """
        probs = self.predict_proba(
            inputs_loader=inputs_loader, extra_embeddings=extra_embeddings
        )
        # `_quantiles` has one threshold per cluster; broadcast over inputs.
        return probs > self._quantiles[None]

    def hard_predict(
        self,
        inputs_loader: InputsLoader,
        extra_embeddings: Optional[Array] = None,
    ) -> Array:
        """
        For each input, predict which cluster the inputs are most likely to belong to.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            An array of bools determining whether an input is predicted to belong to a cluster or not.
            Exactly one True will be given for each input.
        """
        probs = self.predict_proba(
            inputs_loader=inputs_loader, extra_embeddings=extra_embeddings
        )

        # One-hot encode the argmax cluster for each input.
        bool_preds = np.zeros_like(probs, dtype=bool)
        bool_preds[np.arange(len(probs)), np.argmax(probs, axis=1)] = True
        return bool_preds

    @property
    def clustering_model(self):
        return self._clustering_model

    def _store_thresholds(self, probs, quantile_threshold: float) -> None:
        # Per-cluster quantile over the training inputs (axis=0), so the
        # result has one threshold per cluster, as `soft_predict` expects.
        # NOTE: the original stored `self._threshold` (never read) computed
        # over axis=1, leaving `self._quantiles` as None.
        self._quantiles = np.quantile(probs, quantile_threshold, axis=0)

    def _normalize(self, embeddings: Array) -> Array:
        if self._embeddings_mean is None or self._embeddings_std is None:
            raise ValueError("The `fit` method must be run first.")
        # Return a new array instead of mutating the caller's argument.
        return (embeddings - self._embeddings_mean) / self._embeddings_std

    def _get_concat_embeddings(
        self, inputs_loader: InputsLoader, extra_embeddings: Optional[Array] = None
    ) -> Array:
        embeddings = self.embedding_manager.get(inputs_loader)
        if extra_embeddings is not None:
            if len(embeddings) != len(extra_embeddings):
                raise ValueError(
                    "The total number of inputs must match the length of `extra_embeddings`."
                )
            embeddings = np.concatenate((embeddings, extra_embeddings), axis=1)
        return embeddings

    def _store_embeddings_stats(self, embeddings: Array) -> None:
        # Use the attribute names declared in `__init__` so that
        # `_normalize`'s "fit first" guard works as intended.
        self._embeddings_mean = np.mean(embeddings, axis=0, keepdims=True)
        self._embeddings_std = np.std(embeddings, axis=0, keepdims=True)
26 changes: 26 additions & 0 deletions tests/fortuna/hallucination/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import unittest

from jax import random

from fortuna.data import InputsLoader
from fortuna.hallucination.embedding import EmbeddingManager


class TestEmbeddingsManager(unittest.TestCase):
    """Unit tests for `EmbeddingManager`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_inputs = 10
        self.n_features = 4
        self.n_reduced_features = 3
        inputs = random.normal(
            random.PRNGKey(0), shape=(self.n_inputs, self.n_features)
        )
        self.inputs_loader = InputsLoader.from_array_inputs(inputs, batch_size=2)
        self.embedding_manager = EmbeddingManager(
            encoding_fn=lambda x: 1 - x,
            reduction_fn=lambda x: x[:, : self.n_reduced_features],
        )

    def test_get(self):
        """Reduced embeddings have one row per input and the reduced width."""
        result = self.embedding_manager.get(self.inputs_loader)
        expected_shape = (self.n_inputs, self.n_reduced_features)
        assert result.shape == expected_shape
87 changes: 87 additions & 0 deletions tests/fortuna/hallucination/grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest

from jax import random
import numpy as np
from sklearn.mixture import GaussianMixture

from fortuna.data import InputsLoader
from fortuna.hallucination.embedding import EmbeddingManager
from fortuna.hallucination.grouping.clustering.base import GroupingModel


class GroupingModelTest(unittest.TestCase):
    """Unit tests for `GroupingModel`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_inputs = 10
        self.n_features = 4
        self.n_reduced_features = 3
        self.n_extra_features = 5
        self.inputs_loader = InputsLoader.from_array_inputs(
            random.normal(random.PRNGKey(0), shape=(self.n_inputs, self.n_features)),
            batch_size=2,
        )
        self.grouping_model = GroupingModel(
            embedding_manager=EmbeddingManager(
                encoding_fn=lambda x: 1 - x,
                reduction_fn=lambda x: x[:, : self.n_reduced_features],
            )
        )
        self.extra_embeddings = random.normal(
            random.PRNGKey(0), shape=(self.n_inputs, self.n_extra_features)
        )
        self.clustering_models = [GaussianMixture(n_components=i) for i in range(2, 4)]

    def test_all(self):
        # Fit without extra embeddings.
        self.grouping_model.fit(
            inputs_loader=self.inputs_loader,
            extra_embeddings=None,
            clustering_models=self.clustering_models,
        )
        self._check_shape_types(extra_embeddings=None)

        # Fit with extra embeddings.
        self.grouping_model.fit(
            inputs_loader=self.inputs_loader,
            extra_embeddings=self.extra_embeddings,
            clustering_models=self.clustering_models,
        )
        self._check_shape_types(extra_embeddings=self.extra_embeddings)

        # An empty list of clustering models must be rejected.
        with self.assertRaises(ValueError):
            self.grouping_model.fit(
                inputs_loader=self.inputs_loader,
                extra_embeddings=None,
                clustering_models=[],
            )

        # Mismatched `extra_embeddings` length must be rejected. Valid
        # clustering models are passed here so that the length check — not
        # the empty-list check — is what triggers the error.
        with self.assertRaises(ValueError):
            self.grouping_model.fit(
                inputs_loader=self.inputs_loader,
                extra_embeddings=np.zeros((self.n_inputs + 1, 2)),
                clustering_models=self.clustering_models,
            )

    def _check_shape_types(self, extra_embeddings):
        probs = self.grouping_model.predict_proba(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        hard_preds = self.grouping_model.hard_predict(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        # Fixed: the original mistakenly called `hard_predict` here, so
        # `soft_predict` was never exercised.
        soft_preds = self.grouping_model.soft_predict(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        assert probs.shape == (
            self.n_inputs,
            self.grouping_model._clustering_model.n_components,
        )
        assert soft_preds.shape == (
            self.n_inputs,
            self.grouping_model._clustering_model.n_components,
        )
        assert hard_preds.shape == (
            self.n_inputs,
            self.grouping_model._clustering_model.n_components,
        )
        assert soft_preds.dtype == bool
        assert hard_preds.dtype == bool
        # Exactly one cluster per input in the hard predictions.
        assert np.allclose(hard_preds.sum(1), 1)

0 comments on commit 7e5e0a9

Please sign in to comment.