Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grouping mode for hallucination detection #149

Merged
merged 65 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
52e96ea
edit installation instructions in readme
gianlucadetommaso May 15, 2023
5e0076d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
4c7fd28
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
6cb6581
bump up version
gianlucadetommaso May 15, 2023
1b39780
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
cb2b49a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
14e3ca4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 25, 2023
580067d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 27, 2023
048ef09
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 2, 2023
ad542a4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
41417c1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
64be374
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
a2d0f34
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
66bba06
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
911aa82
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
01f959b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
79f8dca
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
4dea50f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 21, 2023
1ced008
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6992692
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
b2540c1
make small change in readme because of publish to pypi error
gianlucadetommaso Jul 18, 2023
2362998
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6e030f2
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
9bd6f67
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
c5bc94f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
d3ab46b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 26, 2023
0e2aca5
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 26, 2023
9520273
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
e9c4108
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
bc64a01
bump up version
gianlucadetommaso Jul 30, 2023
25072da
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
e27b378
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
a175e16
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 1, 2023
6e202f1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 1, 2023
635e7c9
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 9, 2023
8e23b32
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 16, 2023
f5efef8
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 24, 2023
958b245
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 24, 2023
577d169
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 28, 2023
69a454e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 30, 2023
6e880ba
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 30, 2023
f606545
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 11, 2023
63e09bb
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 11, 2023
b2402b5
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 12, 2023
591d842
refactor tabular analysis of benchmarks
gianlucadetommaso Sep 13, 2023
3dcf217
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 13, 2023
d1b5b4a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 18, 2023
b4c161e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 21, 2023
744dff1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 21, 2023
a22f97f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 24, 2023
fffdd76
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 26, 2023
c23d16d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 26, 2023
1cb2917
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 27, 2023
9c1d07a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 29, 2023
4b83638
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 10, 2023
610fc37
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 10, 2023
e5b67ba
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 10, 2023
1f03d4e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 10, 2023
d49ed29
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 11, 2023
8200e42
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 19, 2023
882733b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 19, 2023
c8ca7e6
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 27, 2023
b1e67fc
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 30, 2023
e6b8c85
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Oct 30, 2023
7e5e0a9
add grouping model for hallucination detection
gianlucadetommaso Nov 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fortuna/hallucination/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from fortuna.hallucination.embedding import EmbeddingManager
from fortuna.hallucination.grouping.clustering.base import GroupingModel
35 changes: 35 additions & 0 deletions fortuna/hallucination/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Callable

import numpy as np
from tqdm import (
tqdm,
trange,
)

from fortuna.data import InputsLoader
from fortuna.typing import Array


class EmbeddingManager:
    """Compute reduced embeddings for the inputs of a loader.

    Parameters
    ----------
    encoding_fn: Callable[[Array], Array]
        A function mapping a single input to its embedding vector.
    reduction_fn: Callable[[Array], Array]
        A function reducing the feature dimension of a 2D array of embeddings.
    """

    def __init__(
        self,
        encoding_fn: Callable[["Array"], "Array"],
        reduction_fn: Callable[["Array"], "Array"],
    ):
        self.encoding_fn = encoding_fn
        self.reduction_fn = reduction_fn

    def get(self, inputs_loader: "InputsLoader") -> "Array":
        """Encode every input in the loader and reduce the embedding dimension.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs; iterating it yields batches of inputs.

        Returns
        -------
        Array
            A 2D array with one reduced embedding per input.
        """
        embeddings = []
        for inputs in tqdm(inputs_loader, desc="Batch"):
            # Encode each input in the batch and stack the embedding vectors
            # into one 2D array per batch. The original converted each
            # embedding via `.tolist()` first, which is a redundant
            # array->list->array round-trip: np.vstack accepts arrays directly.
            embeddings.append(
                np.vstack(
                    [
                        self.encoding_fn(inputs[i])
                        for i in trange(len(inputs), desc="Encode")
                    ]
                )
            )
        embeddings = np.concatenate(embeddings, axis=0)
        return self.reduction_fn(embeddings)
Empty file.
Empty file.
178 changes: 178 additions & 0 deletions fortuna/hallucination/grouping/clustering/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from typing import (
List,
Optional,
)

import numpy as np

from fortuna.data import InputsLoader
from fortuna.hallucination.embedding import EmbeddingManager
from fortuna.typing import Array


class GroupingModel:
    """
    Grouping model based on clustering of embeddings.

    Embeddings are computed with the embedding manager, optionally concatenated
    feature-wise with extra embeddings, normalized with statistics stored at fit
    time, and assigned to clusters by the best (lowest-BIC) clustering model.
    """

    def __init__(
        self,
        embedding_manager: "EmbeddingManager",
        quantile_proba_threshold: float = 0.8,
    ):
        self.embedding_manager = embedding_manager
        self._clustering_model = None
        # Normalization statistics and per-cluster probability thresholds;
        # all three are populated by `fit`.
        self._embeddings_mean = None
        self._embeddings_std = None
        self._quantiles = None
        self.quantile_proba_threshold = quantile_proba_threshold

    def fit(
        self,
        inputs_loader: "InputsLoader",
        clustering_models: List,
        extra_embeddings: Optional["Array"] = None,
    ) -> None:
        """
        Fit the model.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.
        clustering_models: List
            A list of clustering models. Each clustering model must include the
            following methods:

            - `fit`, to fit the model;
            - `bic`, to compute the BIC;
            - `predict_proba`, to get the predicted probabilities;
            - `predict`, to get the predictions.

            An example of a valid clustering model is
            `sklearn.mixture.GaussianMixture`.
        extra_embeddings: Optional[Array]
            An extra array of embeddings, concatenated feature-wise to the
            computed ones.

        Raises
        ------
        ValueError
            If `clustering_models` is not a non-empty list, or if
            `extra_embeddings` does not match the number of inputs.
        """
        if not isinstance(clustering_models, list) or not len(clustering_models):
            raise ValueError("`clustering_models` must be a non-empty list.")
        embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings)
        self._store_embeddings_stats(embeddings)
        embeddings = self._normalize(embeddings)

        # Keep the candidate model with the lowest BIC on the training data.
        best_bic = np.inf
        best_model = None
        for clustering_model in clustering_models:
            model = clustering_model.fit(embeddings)
            bic = clustering_model.bic(embeddings)
            if bic < best_bic:
                best_bic = bic
                best_model = model
        self._clustering_model = best_model

        probs = self._clustering_model.predict_proba(embeddings)
        self._store_thresholds(
            probs=probs, quantile_threshold=self.quantile_proba_threshold
        )

    def predict_proba(
        self, inputs_loader: "InputsLoader", extra_embeddings: Optional["Array"] = None
    ) -> "Array":
        """
        For each input, predict the probability of belonging to each cluster.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            Predicted probabilities, with shape `(n_inputs, n_clusters)`.
        """
        if self._clustering_model is None:
            raise ValueError("The `fit` method must be run first.")
        embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings)
        embeddings = self._normalize(embeddings)
        return self._clustering_model.predict_proba(embeddings)

    def soft_predict(
        self,
        inputs_loader: "InputsLoader",
        extra_embeddings: Optional["Array"] = None,
    ) -> "Array":
        """
        For each input, predict which clusters the inputs are most likely to belong to.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            An array of bools determining whether an input is predicted to belong to a cluster or not.
        """
        probs = self.predict_proba(
            inputs_loader=inputs_loader, extra_embeddings=extra_embeddings
        )
        # Broadcast the per-cluster thresholds (shape `(n_clusters,)`) against
        # the probabilities (shape `(n_inputs, n_clusters)`).
        return probs > self._quantiles[None]

    def hard_predict(
        self,
        inputs_loader: "InputsLoader",
        extra_embeddings: Optional["Array"] = None,
    ) -> "Array":
        """
        For each input, predict which cluster the inputs are most likely to belong to.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            An array of bools determining whether an input is predicted to belong to a cluster or not.
            Exactly one True will be given for each input.
        """
        probs = self.predict_proba(
            inputs_loader=inputs_loader, extra_embeddings=extra_embeddings
        )

        # One-hot over the argmax cluster of each row.
        bool_preds = np.zeros_like(probs, dtype=bool)
        bool_preds[np.arange(len(probs)), np.argmax(probs, axis=1)] = True
        return bool_preds

    @property
    def clustering_model(self):
        # The clustering model selected by `fit`, or None before fitting.
        return self._clustering_model

    def _store_thresholds(self, probs: "Array", quantile_threshold: float) -> None:
        # One threshold per cluster: quantile over the inputs axis (axis=0),
        # shape `(n_clusters,)`, so `soft_predict` can broadcast against probs.
        # BUG FIX: the original stored into `self._threshold` (never read;
        # `soft_predict` reads `self._quantiles`) and used axis=1, whose
        # result of shape `(n_inputs,)` cannot broadcast with probs.
        self._quantiles = np.quantile(probs, quantile_threshold, axis=0)

    def _normalize(self, embeddings: "Array") -> "Array":
        if self._embeddings_mean is None or self._embeddings_std is None:
            raise ValueError("The `fit` method must be run first.")
        # Return a new array instead of mutating the caller's one in place.
        # NOTE(review): a zero-variance feature would divide by zero here —
        # presumably the embeddings are assumed non-degenerate; confirm.
        return (embeddings - self._embeddings_mean) / self._embeddings_std

    def _get_concat_embeddings(
        self, inputs_loader: "InputsLoader", extra_embeddings: Optional["Array"] = None
    ) -> "Array":
        # Compute embeddings and, if given, concatenate extras feature-wise.
        embeddings = self.embedding_manager.get(inputs_loader)
        if extra_embeddings is not None:
            if len(embeddings) != len(extra_embeddings):
                raise ValueError(
                    "The total number of inputs must match the length of `extra_embeddings`."
                )
            embeddings = np.concatenate((embeddings, extra_embeddings), axis=1)
        return embeddings

    def _store_embeddings_stats(self, embeddings: "Array") -> None:
        # Per-feature statistics used by `_normalize`, kept with `keepdims`
        # so they broadcast over rows.
        # BUG FIX: the original wrote `self._mean`/`self._std`, names that
        # `__init__` never declared and that shadowed the intended
        # `_embeddings_mean`/`_embeddings_std`; `_normalize` also read the
        # undeclared names, raising AttributeError instead of the intended
        # ValueError when called before `fit`.
        self._embeddings_mean = np.mean(embeddings, axis=0, keepdims=True)
        self._embeddings_std = np.std(embeddings, axis=0, keepdims=True)
26 changes: 26 additions & 0 deletions tests/fortuna/hallucination/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import unittest

from jax import random

from fortuna.data import InputsLoader
from fortuna.hallucination.embedding import EmbeddingManager


class TestEmbeddingsManager(unittest.TestCase):
    """Shape test for `EmbeddingManager.get`."""

    def setUp(self) -> None:
        # `setUp` (rather than overriding `__init__`) is the unittest idiom:
        # it runs freshly before every test method.
        self.n_inputs = 10
        self.n_features = 4
        self.n_reduced_features = 3
        self.inputs_loader = InputsLoader.from_array_inputs(
            random.normal(random.PRNGKey(0), shape=(self.n_inputs, self.n_features)),
            batch_size=2,
        )
        self.embedding_manager = EmbeddingManager(
            encoding_fn=lambda x: 1 - x,
            reduction_fn=lambda x: x[:, : self.n_reduced_features],
        )

    def test_get(self):
        # Encoding preserves the number of inputs; reduction trims the
        # feature columns down to `n_reduced_features`.
        embeddings = self.embedding_manager.get(self.inputs_loader)
        assert embeddings.shape == (self.n_inputs, self.n_reduced_features)
87 changes: 87 additions & 0 deletions tests/fortuna/hallucination/grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest

from jax import random
import numpy as np
from sklearn.mixture import GaussianMixture

from fortuna.data import InputsLoader
from fortuna.hallucination.embedding import EmbeddingManager
from fortuna.hallucination.grouping.clustering.base import GroupingModel


class GroupingModelTest(unittest.TestCase):
    """End-to-end tests for `GroupingModel` fitting and prediction."""

    def setUp(self) -> None:
        # `setUp` (rather than overriding `__init__`) is the unittest idiom:
        # it runs freshly before every test method.
        self.n_inputs = 10
        self.n_features = 4
        self.n_reduced_features = 3
        self.n_extra_features = 5
        self.inputs_loader = InputsLoader.from_array_inputs(
            random.normal(random.PRNGKey(0), shape=(self.n_inputs, self.n_features)),
            batch_size=2,
        )
        self.grouping_model = GroupingModel(
            embedding_manager=EmbeddingManager(
                encoding_fn=lambda x: 1 - x,
                reduction_fn=lambda x: x[:, : self.n_reduced_features],
            )
        )
        self.extra_embeddings = random.normal(
            random.PRNGKey(0), shape=(self.n_inputs, self.n_extra_features)
        )
        self.clustering_models = [GaussianMixture(n_components=i) for i in range(2, 4)]

    def test_all(self):
        # Fit without extra embeddings.
        self.grouping_model.fit(
            inputs_loader=self.inputs_loader,
            extra_embeddings=None,
            clustering_models=self.clustering_models,
        )
        self._check_shape_types(extra_embeddings=None)

        # Fit with extra embeddings concatenated feature-wise.
        self.grouping_model.fit(
            inputs_loader=self.inputs_loader,
            extra_embeddings=self.extra_embeddings,
            clustering_models=self.clustering_models,
        )
        self._check_shape_types(extra_embeddings=self.extra_embeddings)

        # An empty list of clustering models is rejected.
        with self.assertRaises(ValueError):
            self.grouping_model.fit(
                inputs_loader=self.inputs_loader,
                extra_embeddings=None,
                clustering_models=[],
            )

        # Mismatched extra-embedding length is rejected.
        with self.assertRaises(ValueError):
            self.grouping_model.fit(
                inputs_loader=self.inputs_loader,
                extra_embeddings=np.zeros((self.n_inputs + 1, 2)),
                clustering_models=[],
            )

    def _check_shape_types(self, extra_embeddings):
        # Shared shape/dtype checks for all predict variants.
        probs = self.grouping_model.predict_proba(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        hard_preds = self.grouping_model.hard_predict(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        # BUG FIX: the original called `hard_predict` a second time here, so
        # `soft_predict` was never exercised by the test.
        soft_preds = self.grouping_model.soft_predict(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        # Use the public property rather than the private attribute.
        n_components = self.grouping_model.clustering_model.n_components
        assert probs.shape == (self.n_inputs, n_components)
        assert soft_preds.shape == (self.n_inputs, n_components)
        assert hard_preds.shape == (self.n_inputs, n_components)
        assert soft_preds.dtype == bool
        assert hard_preds.dtype == bool
        # Hard predictions are one-hot: exactly one cluster per input.
        assert np.allclose(hard_preds.sum(1), 1)
Loading