-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add grouping model for hallucination detection
- Loading branch information
1 parent
e6b8c85
commit 7e5e0a9
Showing
7 changed files
with
328 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from fortuna.hallucination.embedding import EmbeddingManager | ||
from fortuna.hallucination.grouping.clustering.base import GroupingModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Callable | ||
|
||
import numpy as np | ||
from tqdm import ( | ||
tqdm, | ||
trange, | ||
) | ||
|
||
from fortuna.data import InputsLoader | ||
from fortuna.typing import Array | ||
|
||
|
||
class EmbeddingManager:
    """Compute embeddings for all inputs served by a loader.

    Each input is mapped through ``encoding_fn``; the stacked encodings are
    then passed once through ``reduction_fn`` (e.g. a dimensionality
    reduction) to produce the final embedding matrix.
    """

    def __init__(
        self,
        encoding_fn: Callable[[Array], Array],
        reduction_fn: Callable[[Array], Array],
    ):
        # Function mapping a single input to its encoded representation.
        self.encoding_fn = encoding_fn
        # Function applied once to the full stacked encoding matrix.
        self.reduction_fn = reduction_fn

    def get(self, inputs_loader: InputsLoader) -> Array:
        """Return the reduced embeddings for every input in `inputs_loader`.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.

        Returns
        -------
        Array
            A 2D array with one embedding row per input.
        """
        batch_embeddings = []
        for inputs in tqdm(inputs_loader, desc="Batch"):
            encoded_rows = [
                self.encoding_fn(inputs[idx]).tolist()
                for idx in trange(len(inputs), desc="Encode")
            ]
            batch_embeddings.append(np.vstack(encoded_rows))
        stacked = np.concatenate(batch_embeddings, axis=0)
        return self.reduction_fn(stacked)
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from typing import ( | ||
List, | ||
Optional, | ||
) | ||
|
||
import numpy as np | ||
|
||
from fortuna.data import InputsLoader | ||
from fortuna.hallucination.embedding import EmbeddingManager | ||
from fortuna.typing import Array | ||
|
||
|
||
class GroupingModel:
    """
    Grouping model based on clustering of embeddings.

    Embeddings are computed with the given :class:`EmbeddingManager`,
    standardized using statistics stored at fitting time, and clustered with
    the candidate clustering model achieving the lowest BIC.
    """

    def __init__(
        self, embedding_manager: EmbeddingManager, quantile_proba_threshold: float = 0.8
    ):
        self.embedding_manager = embedding_manager
        self._clustering_model = None
        # Standardization statistics, set by `_store_embeddings_stats` during `fit`.
        # NOTE: named `_mean`/`_std` consistently with `_normalize`; the original
        # initialized `_embeddings_mean`/`_embeddings_std`, which nothing read.
        self._mean = None
        self._std = None
        # Per-cluster probability thresholds, set by `_store_thresholds` during `fit`.
        self._quantiles = None
        self.quantile_proba_threshold = quantile_proba_threshold

    def fit(
        self,
        inputs_loader: InputsLoader,
        clustering_models: List,
        extra_embeddings: Optional[Array] = None,
    ) -> None:
        """
        Fit the model.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.
        clustering_models: List
            A list of clustering models. Each clustering model must include the following methods:

            - `fit`, to fit the model;
            - `bic`, to compute the BIC;
            - `predict_proba`, to get the predicted probabilities;
            - `predict`, to get the predictions.

            An example of valid clustering model is `sklearn.mixture.GaussianMixture`.
        extra_embeddings: Optional[Array]
            An extra array of embeddings, concatenated feature-wise to the computed ones.

        Raises
        ------
        ValueError
            If `clustering_models` is not a non-empty list, or if the number of
            inputs does not match the length of `extra_embeddings`.
        """
        if not isinstance(clustering_models, list) or not len(clustering_models):
            raise ValueError("`clustering_models` must be a non-empty list.")
        embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings)
        self._store_embeddings_stats(embeddings)
        embeddings = self._normalize(embeddings)

        # Model selection: keep the candidate with the lowest BIC.
        best_bic = np.inf
        best_model = None
        for clustering_model in clustering_models:
            model = clustering_model.fit(embeddings)
            bic = clustering_model.bic(embeddings)
            if bic < best_bic:
                best_bic = bic
                best_model = model
        self._clustering_model = best_model

        probs = self._clustering_model.predict_proba(embeddings)
        self._store_thresholds(
            probs=probs, quantile_threshold=self.quantile_proba_threshold
        )

    def predict_proba(
        self, inputs_loader: InputsLoader, extra_embeddings: Optional[Array] = None
    ) -> Array:
        """
        For each input, predict the probability of belonging to each cluster.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            Predicted probabilities, with shape `(n_inputs, n_clusters)`.
        """
        if self._clustering_model is None:
            raise ValueError("The `fit` method must be run first.")
        embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings)
        embeddings = self._normalize(embeddings)
        return self._clustering_model.predict_proba(embeddings)

    def soft_predict(
        self,
        inputs_loader: InputsLoader,
        extra_embeddings: Optional[Array] = None,
    ) -> Array:
        """
        For each input, predict which clusters the inputs are most likely to belong to.

        An input belongs to a cluster when its predicted probability exceeds
        that cluster's stored quantile threshold, so zero or several clusters
        may be flagged per input.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            An array of bools determining whether an input is predicted to belong to a cluster or not.
        """
        probs = self.predict_proba(
            inputs_loader=inputs_loader, extra_embeddings=extra_embeddings
        )
        # `_quantiles` has shape (n_clusters,); broadcast against (n_inputs, n_clusters).
        return probs > self._quantiles[None]

    def hard_predict(
        self,
        inputs_loader: InputsLoader,
        extra_embeddings: Optional[Array] = None,
    ) -> Array:
        """
        For each input, predict which cluster the inputs are most likely to belong to.

        Parameters
        ----------
        inputs_loader: InputsLoader
            A loader of inputs.
        extra_embeddings: Optional[Array]
            An extra array of embeddings.

        Returns
        -------
        Array
            An array of bools determining whether an input is predicted to belong to a cluster or not.
            Exactly one True will be given for each input.
        """
        probs = self.predict_proba(
            inputs_loader=inputs_loader, extra_embeddings=extra_embeddings
        )

        # One-hot mask of the argmax cluster for each input.
        bool_preds = np.zeros_like(probs, dtype=bool)
        bool_preds[np.arange(len(probs)), np.argmax(probs, axis=1)] = True
        return bool_preds

    @property
    def clustering_model(self):
        # The clustering model selected during `fit`, or None before fitting.
        return self._clustering_model

    def _store_thresholds(self, probs: Array, quantile_threshold: float) -> None:
        # One threshold per cluster: quantile taken over the inputs axis (axis=0),
        # so that `soft_predict` can broadcast `probs > self._quantiles[None]`.
        # The original used `axis=1` and stored into `self._threshold`, which
        # `soft_predict` never read.
        self._quantiles = np.quantile(probs, quantile_threshold, axis=0)

    def _normalize(self, embeddings: Array) -> Array:
        # Standardize features using the statistics stored at fitting time.
        if self._mean is None or self._std is None:
            raise ValueError("The `fit` method must be run first.")
        # Return a new array instead of mutating the caller's array in place.
        return (embeddings - self._mean) / self._std

    def _get_concat_embeddings(
        self, inputs_loader: InputsLoader, extra_embeddings: Optional[Array] = None
    ) -> Array:
        # Compute embeddings and optionally concatenate extra features column-wise.
        embeddings = self.embedding_manager.get(inputs_loader)
        if extra_embeddings is not None:
            if len(embeddings) != len(extra_embeddings):
                raise ValueError(
                    "The total number of inputs must match the length of `extra_embeddings`."
                )
            embeddings = np.concatenate((embeddings, extra_embeddings), axis=1)
        return embeddings

    def _store_embeddings_stats(self, embeddings: Array) -> None:
        # Per-feature mean/std with keepdims so they broadcast in `_normalize`.
        self._mean = np.mean(embeddings, axis=0, keepdims=True)
        self._std = np.std(embeddings, axis=0, keepdims=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import unittest | ||
|
||
from jax import random | ||
|
||
from fortuna.data import InputsLoader | ||
from fortuna.hallucination.embedding import EmbeddingManager | ||
|
||
|
||
class TestEmbeddingsManager(unittest.TestCase):
    """Unit tests for :class:`EmbeddingManager`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_inputs = 10
        self.n_features = 4
        self.n_reduced_features = 3
        inputs = random.normal(
            random.PRNGKey(0), shape=(self.n_inputs, self.n_features)
        )
        self.inputs_loader = InputsLoader.from_array_inputs(inputs, batch_size=2)
        self.embedding_manager = EmbeddingManager(
            encoding_fn=lambda x: 1 - x,
            reduction_fn=lambda x: x[:, : self.n_reduced_features],
        )

    def test_get(self):
        # Embeddings keep one row per input and reduce the feature dimension.
        result = self.embedding_manager.get(self.inputs_loader)
        assert result.shape == (self.n_inputs, self.n_reduced_features)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import unittest | ||
|
||
from jax import random | ||
import numpy as np | ||
from sklearn.mixture import GaussianMixture | ||
|
||
from fortuna.data import InputsLoader | ||
from fortuna.hallucination.embedding import EmbeddingManager | ||
from fortuna.hallucination.grouping.clustering.base import GroupingModel | ||
|
||
|
||
class GroupingModelTest(unittest.TestCase):
    """Unit tests for :class:`GroupingModel`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_inputs = 10
        self.n_features = 4
        self.n_reduced_features = 3
        self.n_extra_features = 5
        self.inputs_loader = InputsLoader.from_array_inputs(
            random.normal(random.PRNGKey(0), shape=(self.n_inputs, self.n_features)),
            batch_size=2,
        )
        self.grouping_model = GroupingModel(
            embedding_manager=EmbeddingManager(
                encoding_fn=lambda x: 1 - x,
                reduction_fn=lambda x: x[:, : self.n_reduced_features],
            )
        )
        self.extra_embeddings = random.normal(
            random.PRNGKey(0), shape=(self.n_inputs, self.n_extra_features)
        )
        self.clustering_models = [GaussianMixture(n_components=i) for i in range(2, 4)]

    def test_all(self):
        # Fit without extra embeddings.
        self.grouping_model.fit(
            inputs_loader=self.inputs_loader,
            extra_embeddings=None,
            clustering_models=self.clustering_models,
        )
        self._check_shape_types(extra_embeddings=None)

        # Fit with extra embeddings concatenated feature-wise.
        self.grouping_model.fit(
            inputs_loader=self.inputs_loader,
            extra_embeddings=self.extra_embeddings,
            clustering_models=self.clustering_models,
        )
        self._check_shape_types(extra_embeddings=self.extra_embeddings)

        # An empty list of clustering models must be rejected.
        with self.assertRaises(ValueError):
            self.grouping_model.fit(
                inputs_loader=self.inputs_loader,
                extra_embeddings=None,
                clustering_models=[],
            )

        # A size mismatch between inputs and extra embeddings must be rejected.
        # Valid clustering models are passed here; the original passed an empty
        # list, so the ValueError came from the empty-list check instead of the
        # length-mismatch check this case is meant to cover.
        with self.assertRaises(ValueError):
            self.grouping_model.fit(
                inputs_loader=self.inputs_loader,
                extra_embeddings=np.zeros((self.n_inputs + 1, 2)),
                clustering_models=self.clustering_models,
            )

    def _check_shape_types(self, extra_embeddings):
        # Shared shape/dtype checks for a fitted model.
        probs = self.grouping_model.predict_proba(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        hard_preds = self.grouping_model.hard_predict(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        # Fixed copy-paste typo: the original called `hard_predict` here, so
        # `soft_predict` was never exercised.
        soft_preds = self.grouping_model.soft_predict(
            inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings
        )
        assert probs.shape == (
            self.n_inputs,
            self.grouping_model._clustering_model.n_components,
        )
        assert soft_preds.shape == (
            self.n_inputs,
            self.grouping_model._clustering_model.n_components,
        )
        assert hard_preds.shape == (
            self.n_inputs,
            self.grouping_model._clustering_model.n_components,
        )
        assert soft_preds.dtype == bool
        assert hard_preds.dtype == bool
        # Hard predictions are one-hot: exactly one True per input.
        assert np.allclose(hard_preds.sum(1), 1)