diff --git a/fortuna/hallucination/__init__.py b/fortuna/hallucination/__init__.py new file mode 100644 index 00000000..8ee23fec --- /dev/null +++ b/fortuna/hallucination/__init__.py @@ -0,0 +1,2 @@ +from fortuna.hallucination.embedding import EmbeddingManager +from fortuna.hallucination.grouping.clustering.base import GroupingModel diff --git a/fortuna/hallucination/embedding.py b/fortuna/hallucination/embedding.py new file mode 100644 index 00000000..9e44bf52 --- /dev/null +++ b/fortuna/hallucination/embedding.py @@ -0,0 +1,35 @@ +from typing import Callable + +import numpy as np +from tqdm import ( + tqdm, + trange, +) + +from fortuna.data import InputsLoader +from fortuna.typing import Array + + +class EmbeddingManager: + def __init__( + self, + encoding_fn: Callable[[Array], Array], + reduction_fn: Callable[[Array], Array], + ): + self.encoding_fn = encoding_fn + self.reduction_fn = reduction_fn + + def get(self, inputs_loader: InputsLoader) -> Array: + embeddings = [] + for inputs in tqdm(inputs_loader, desc="Batch"): + embeddings.append( + np.vstack( + [ + self.encoding_fn(inputs[i]).tolist() + for i in trange(len(inputs), desc="Encode") + ] + ) + ) + embeddings = np.concatenate(embeddings, axis=0) + embeddings = self.reduction_fn(embeddings) + return embeddings diff --git a/fortuna/hallucination/grouping/__init__.py b/fortuna/hallucination/grouping/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/hallucination/grouping/clustering/__init__.py b/fortuna/hallucination/grouping/clustering/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/hallucination/grouping/clustering/base.py b/fortuna/hallucination/grouping/clustering/base.py new file mode 100644 index 00000000..d0c4d2d5 --- /dev/null +++ b/fortuna/hallucination/grouping/clustering/base.py @@ -0,0 +1,178 @@ +from typing import ( + List, + Optional, +) + +import numpy as np + +from fortuna.data import InputsLoader +from fortuna.hallucination.embedding import EmbeddingManager +from fortuna.typing import Array + + +class GroupingModel: + """ + Grouping model based on clustering of embeddings. + """ + + def __init__( + self, embedding_manager: EmbeddingManager, quantile_proba_threshold: float = 0.8 + ): + self.embedding_manager = embedding_manager + self._clustering_model = None + self._embeddings_mean = None + self._embeddings_std = None + self._quantiles = None + self.quantile_proba_threshold = quantile_proba_threshold + + def fit( + self, + inputs_loader: InputsLoader, + clustering_models: List, + extra_embeddings: Optional[Array] = None, + ) -> None: + """ + Fit the model. + + Parameters + ---------- + inputs_loader: InputsLoader + A loader of inputs. + clustering_models: List + A list of clustering models. Each clustering model must include the following method: + extra_embeddings: Optional[Array] + An extra array of embeddings. + + - `fit`, to fit the model; + - `bic`, to compute the BIC; + - `predict_proba`, to get the predicted probabilities; + - `predict`, to get the predictions. + An example of valid clustering model is `sklearn.mixture.GaussianMixture`. + """ + if not isinstance(clustering_models, list) or not len(clustering_models): + raise ValueError("`clustering_models` must be a non-empty list.") + embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings) + self._store_embeddings_stats(embeddings) + embeddings = self._normalize(embeddings) + + best_bic = np.inf + for clustering_model in clustering_models: + model = clustering_model.fit(embeddings) + bic = clustering_model.bic(embeddings) + if bic < best_bic: + best_bic = bic + best_model = model + self._clustering_model = best_model + + probs = self._clustering_model.predict_proba(embeddings) + self._store_thresholds( + probs=probs, quantile_threshold=self.quantile_proba_threshold + ) + + def predict_proba( + self, inputs_loader: InputsLoader, extra_embeddings: Optional[Array] = None + ) -> Array: + """ + For each input, predict the probability of belonging to each cluster. + + Parameters + ---------- + inputs_loader: InputsLoader + A loader of inputs + extra_embeddings: Optional[Array] + An extra array of embeddings. + + Returns + ------- + Array + Predicted probabilities. + """ + if self._clustering_model is None: + raise ValueError("The `fit` method must be run first.") + embeddings = self._get_concat_embeddings(inputs_loader, extra_embeddings) + embeddings = self._normalize(embeddings) + return self._clustering_model.predict_proba(embeddings) + + def soft_predict( + self, + inputs_loader: InputsLoader, + extra_embeddings: Optional[Array] = None, + ) -> Array: + """ + For each input, predict which clusters the inputs are most likely to belong to. + + Parameters + ---------- + inputs_loader: InputsLoader + A loader of inputs + extra_embeddings: Optional[Array] + An extra array of embeddings. + + Returns + ------- + Array + An array of bools determining whether an input is predicted to belong to a cluster or not. + """ + probs = self.predict_proba( + inputs_loader=inputs_loader, extra_embeddings=extra_embeddings + ) + return probs > self._quantiles[None] + + def hard_predict( + self, + inputs_loader: InputsLoader, + extra_embeddings: Optional[Array] = None, + ) -> Array: + """ + For each input, predict which cluster the inputs are most likely to belong to. + + Parameters + ---------- + inputs_loader: InputsLoader + A loader of inputs + extra_embeddings: Optional[Array] + An extra array of embeddings. + + Returns + ------- + Array + An array of bools determining whether an input is predicted to belong to a cluster or not. + Exactly one True will be given for each input. + """ + probs = self.predict_proba( + inputs_loader=inputs_loader, extra_embeddings=extra_embeddings + ) + + bool_preds = np.zeros_like(probs, dtype=bool) + bool_preds[np.arange(len(probs)), np.argmax(probs, axis=1)] = True + return bool_preds + + @property + def clustering_model(self): + return self._clustering_model + + def _store_thresholds(self, probs, quantile_threshold: float) -> None: + self._threshold = np.quantile(probs, quantile_threshold, axis=1) + + def _normalize(self, embeddings: Array) -> Array: + if self._mean is None or self._std is None: + raise ValueError("The `fit` method must be run first.") + embeddings -= self._mean + embeddings /= self._std + return embeddings + + def _get_concat_embeddings( + self, inputs_loader: InputsLoader, extra_embeddings: Optional[Array] = None + ) -> Array: + embeddings = self.embedding_manager.get(inputs_loader) + if extra_embeddings is not None: + if len(embeddings) != len(extra_embeddings): + raise ValueError( + "`The total number of inputs must match the length of `extra_embeddings`." + ) + embeddings = np.concatenate((embeddings, extra_embeddings), axis=1) + return embeddings + + def _store_embeddings_stats(self, embeddings): + self._mean = np.mean(embeddings, axis=0, keepdims=True) + self._std = np.std(embeddings, axis=0, keepdims=True) diff --git a/tests/fortuna/hallucination/embeddings.py b/tests/fortuna/hallucination/embeddings.py new file mode 100644 index 00000000..61168bbf --- /dev/null +++ b/tests/fortuna/hallucination/embeddings.py @@ -0,0 +1,26 @@ +import unittest + +from jax import random + +from fortuna.data import InputsLoader +from fortuna.hallucination.embedding import EmbeddingManager + + +class TestEmbeddingsManager(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.n_inputs = 10 + self.n_features = 4 + self.n_reduced_features = 3 + self.inputs_loader = InputsLoader.from_array_inputs( + random.normal(random.PRNGKey(0), shape=(self.n_inputs, self.n_features)), + batch_size=2, + ) + self.embedding_manager = EmbeddingManager( + encoding_fn=lambda x: 1 - x, + reduction_fn=lambda x: x[:, : self.n_reduced_features], + ) + + def test_get(self): + embeddings = self.embedding_manager.get(self.inputs_loader) + assert embeddings.shape == (self.n_inputs, self.n_reduced_features) diff --git a/tests/fortuna/hallucination/grouping.py b/tests/fortuna/hallucination/grouping.py new file mode 100644 index 00000000..a743ab32 --- /dev/null +++ b/tests/fortuna/hallucination/grouping.py @@ -0,0 +1,87 @@ +import unittest + +from jax import random +import numpy as np +from sklearn.mixture import GaussianMixture + +from fortuna.data import InputsLoader +from fortuna.hallucination.embedding import EmbeddingManager +from fortuna.hallucination.grouping.clustering.base import GroupingModel + + +class GroupingModelTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.n_inputs = 10 + self.n_features = 4 + self.n_reduced_features = 3 + self.n_extra_features = 5 + self.inputs_loader = InputsLoader.from_array_inputs( + random.normal(random.PRNGKey(0), shape=(self.n_inputs, self.n_features)), + batch_size=2, + ) + self.grouping_model = GroupingModel( + embedding_manager=EmbeddingManager( + encoding_fn=lambda x: 1 - x, + reduction_fn=lambda x: x[:, : self.n_reduced_features], + ) + ) + self.extra_embeddings = random.normal( + random.PRNGKey(0), shape=(self.n_inputs, self.n_extra_features) + ) + self.clustering_models = [GaussianMixture(n_components=i) for i in range(2, 4)] + + def test_all(self): + self.grouping_model.fit( + inputs_loader=self.inputs_loader, + extra_embeddings=None, + clustering_models=self.clustering_models, + ) + self._check_shape_types(extra_embeddings=None) + + self.grouping_model.fit( + inputs_loader=self.inputs_loader, + extra_embeddings=self.extra_embeddings, + clustering_models=self.clustering_models, + ) + self._check_shape_types(extra_embeddings=self.extra_embeddings) + + with self.assertRaises(ValueError): + self.grouping_model.fit( + inputs_loader=self.inputs_loader, + extra_embeddings=None, + clustering_models=[], + ) + + with self.assertRaises(ValueError): + self.grouping_model.fit( + inputs_loader=self.inputs_loader, + extra_embeddings=np.zeros((self.n_inputs + 1, 2)), + clustering_models=[], + ) + + def _check_shape_types(self, extra_embeddings): + probs = self.grouping_model.predict_proba( + inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings + ) + hard_preds = self.grouping_model.hard_predict( + inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings + ) + soft_preds = self.grouping_model.hard_predict( + inputs_loader=self.inputs_loader, extra_embeddings=extra_embeddings + ) + assert probs.shape == ( + self.n_inputs, + self.grouping_model._clustering_model.n_components, + ) + assert soft_preds.shape == ( + self.n_inputs, + self.grouping_model._clustering_model.n_components, + ) + assert hard_preds.shape == ( + self.n_inputs, + self.grouping_model._clustering_model.n_components, + ) + assert soft_preds.dtype == bool + assert hard_preds.dtype == bool + assert np.allclose(hard_preds.sum(1), 1)