Skip to content

Commit

Permalink
GH-3496: Add OneClassClassifier model
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffpicard committed Aug 6, 2024
1 parent c6a2643 commit 65459d3
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 0 deletions.
243 changes: 243 additions & 0 deletions flair/models/one_class_classification_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

import flair
from flair.data import Dictionary, Sentence, _iter_dataset
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import DocumentEmbeddings
from flair.training_utils import store_embeddings


class OneClassClassifier(flair.nn.Classifier[Sentence]):
"""One Class Classification Model for tasks such as Anomaly Detection.
Task
----
One Class Classification (OCC) tries to identify objects of a specific class amongst all objects, in contrast to
distinguishing between two or more classes.
Example:
-------
The model expects to be trained on a dataset in which every element has the same label_value, e.g. movie reviews
with the label POSITIVE.
During inference, one of two label_values will be added:
- In-class (e.g. another movie review) -> label_value="POSITIVE"
- Anything else (e.g. a wiki page) -> label_value="<unk>"
Architecture
------------
Reconstruction with autoencoder. The score is the reconstruction error from compressing and decompressing the
document embedding. A LOWER score indicates a HIGHER probability of being in-class. The threshold is
calculated as a high percentile of the score distribution of in-class elements from the dev set.
You must set the threshold after training by running `model.threshold = model.calculate_threshold(corpus.dev)`.
"""

def __init__(
self,
embeddings: DocumentEmbeddings,
label_dictionary: Dictionary,
label_type: str,
encoding_dim: int = 128,
threshold: Optional[float] = None,
) -> None:
"""Initializes a OneClassClassifier.
Args:
embeddings: Embeddings to use during training and prediction
label_dictionary: The label to predict. Must contain exactly one class.
label_type: name of the annotation_layer to be predicted in case a corpus has multiple annotations
encoding_dim: The size of the compressed embedding
threshold: The score that separates in-class from out-of-class
"""
super().__init__()
self.embeddings = embeddings
if len(label_dictionary) != 1:
raise ValueError(f"label_dictionary must have exactly 1 element: {label_dictionary}")
self.label_dictionary = label_dictionary
self.label_value = label_dictionary.get_items()[0]
self._label_type = label_type
self.encoding_dim = encoding_dim
self.threshold = threshold

embedding_dim = embeddings.embedding_length
self.encoder = torch.nn.Sequential(
torch.nn.Linear(embedding_dim, encoding_dim * 4),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim * 4, encoding_dim * 2),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim * 2, encoding_dim),
torch.nn.LeakyReLU(True),
)

self.decoder = torch.nn.Sequential(
torch.nn.Linear(encoding_dim, encoding_dim * 2),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim * 2, encoding_dim * 4),
torch.nn.LeakyReLU(True),
torch.nn.Linear(encoding_dim * 4, embedding_dim),
torch.nn.LeakyReLU(True),
)

self.cosine_sim = torch.nn.CosineSimilarity(dim=1)
self.to(flair.device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
x = self.decoder(x)
return x

def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]:
"""Returns Tuple[scalar tensor, num examples]."""
if len(sentences) == 0:
return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0
sentence_tensor = self._sentences_to_tensor(sentences)
reconstructed_sentence_tensor = self.forward(sentence_tensor)
return self._calculate_loss(reconstructed_sentence_tensor, sentence_tensor).sum(), len(sentences)

def predict(
self,
sentences: Union[List[Sentence], Sentence],
mini_batch_size: int = 32,
return_probabilities_for_all_classes: bool = False,
verbose: bool = False,
label_name: Optional[str] = None,
return_loss=False,
embedding_storage_mode="none",
) -> Optional[torch.Tensor]:
"""Predicts the class labels for the given sentences. The labels are directly added to the sentences.
Args:
sentences: list of sentences to predict
mini_batch_size: the amount of sentences that will be predicted within one batch
return_probabilities_for_all_classes: doesn't do anything for this class
verbose: set to True to display a progress bar
return_loss: set to True to return loss
label_name: set this to change the name of the label type that is predicted
embedding_storage_mode: default is 'none' which is the best is most cases.
Only set to 'cpu' or 'gpu' if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively. 'gpu' to store embeddings in GPU memory.
Returns: None. If return_loss is set, returns a scalar tensor
"""
if label_name is None:
label_name = self.label_type

with torch.no_grad():
# make sure it's a list
if not isinstance(sentences, list):
sentences = [sentences]

Sentence.set_context_for_sentences(cast(List[Sentence], sentences))

# filter empty sentences
sentences = [sentence for sentence in sentences if len(sentence) > 0]
if len(sentences) == 0:
return torch.tensor(0.0, requires_grad=True, device=flair.device) if return_loss else None

dataloader = DataLoader(
dataset=FlairDatapointDataset(sentences),
batch_size=mini_batch_size,
)
# progress bar for verbosity
if verbose:
dataloader = tqdm(dataloader, desc="Batch inference")

overall_loss = torch.zeros(1, device=flair.device)
for batch in dataloader:
# stop if all sentences are empty
if not batch:
continue

sentence_tensor = self._sentences_to_tensor(batch)
reconstructed = self.forward(sentence_tensor)
loss_tensor = self._calculate_loss(reconstructed, sentence_tensor)

for sentence, loss in zip(batch, loss_tensor.tolist()):
sentence.remove_labels(label_name)
label_value = self.label_value if self.threshold is not None and loss < self.threshold else "<unk>"
sentence.add_label(typename=label_name, value=label_value, score=loss)

overall_loss += loss_tensor.sum()
store_embeddings(batch, storage_mode=embedding_storage_mode)

return overall_loss if return_loss else None

@property
def label_type(self) -> str:
return self._label_type

def _sentences_to_tensor(self, sentences: List[Sentence]) -> torch.Tensor:
self.embeddings.embed(sentences)
return torch.stack([sentence.embedding for sentence in sentences])

def _calculate_loss(self, predicted: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""Return cosine similarity loss.
Args:
predicted: tensor of shape (batch_size, embedding_size)
labels: tensor of shape (batch_size, embedding_size)
Returns:
tensor of shape (batch_size)
"""
if labels.size(0) == 0:
return torch.tensor(0.0, requires_grad=True, device=flair.device)

return 1 - self.cosine_sim(predicted, labels)

def _get_state_dict(self):
"""Returns the state dictionary for this model."""
model_state = {
**super()._get_state_dict(),
"embeddings": self.embeddings.save_embeddings(use_state_dict=False),
"label_dictionary": self.label_dictionary,
"label_type": self.label_type,
"encoding_dim": self.encoding_dim,
"threshold": self.threshold,
}

return model_state

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
return super()._init_model_with_state_dict(
state,
embeddings=state.get("embeddings"),
label_dictionary=state.get("label_dictionary"),
label_type=state.get("label_type"),
encoding_dim=state.get("encoding_dim"),
threshold=state.get("threshold"),
**kwargs,
)

@classmethod
def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "OneClassClassifier":
from typing import cast

return cast("OneClassClassifier", super().load(model_path=model_path))

def calculate_threshold(self, dataset: Dataset[Sentence], quantile=0.995) -> float:
"""Determine the score threshold to consider a Sentence in-class.
This implementation returns the score at which `quantile` of `dataset` will be considered in-class. Intended
for use-cases desiring high-recall.
"""

def score(sentence: Sentence) -> float:
sentence_tensor = self._sentences_to_tensor([sentence])
reconstructed = self.forward(sentence_tensor)
loss_tensor = self._calculate_loss(reconstructed, sentence_tensor)
return loss_tensor.tolist()[0]

scores = [
score(sentence)
for sentence in _iter_dataset(dataset)
if sentence.get_labels(self.label_type)[0].value == self.label_value
]
threshold = np.quantile(scores, quantile)
return threshold
44 changes: 44 additions & 0 deletions tests/models/test_one_class_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

import flair
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models.one_class_classification_model import OneClassClassifier
from flair.trainers import ModelTrainer
from tests.model_test_utils import BaseModelTest


class TestOneClassClassifier(BaseModelTest):
model_cls = OneClassClassifier
train_label_type = "topic"
training_args = {
"max_epochs": 2,
}

@pytest.fixture()
def corpus(self, tasks_base_path):
label_type = "topic"
corpus = flair.datasets.ClassificationCorpus(tasks_base_path / "imdb", label_type=label_type)
corpus._train = [x for x in corpus.train if x.get_labels(label_type)[0].value == "POSITIVE"]
return corpus

@pytest.fixture()
def embeddings(self):
return TransformerDocumentEmbeddings(model="distilbert-base-uncased", layers="-1", fine_tune=True)

@pytest.mark.integration()
def test_train_load_use_one_class_classifier(self, results_base_path, corpus, example_sentence, embeddings):
label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)

model = self.model_cls(embeddings=embeddings, label_dictionary=label_dict, label_type=self.train_label_type)
trainer = ModelTrainer(model, corpus)

trainer.train(results_base_path, shuffle=False, **self.training_args)

del trainer, model, label_dict, corpus
loaded_model = self.model_cls.load(results_base_path / "final-model.pt")

loaded_model.predict(example_sentence)
loaded_model.predict([example_sentence, self.empty_sentence])
loaded_model.predict([self.empty_sentence])

assert example_sentence.get_labels(self.train_label_type)[0].value in {"POSITIVE", "<unk>"}

0 comments on commit 65459d3

Please sign in to comment.