Refactored CrossEncoder into our own wrapper class to support head training #88

Open · wants to merge 8 commits into base: dev
3 changes: 3 additions & 0 deletions autointent/_transformers/__init__.py
@@ -0,0 +1,3 @@
from ._nli_transformer import NLITransformer

__all__ = ["NLITransformer"]
autointent/_transformers/_nli_transformer.py
@@ -1,17 +1,21 @@
"""CrossEncoderWithLogreg class for cross-encoder-based binary classification with logistic regression."""
"""NLITransformer class for cross-encoder-based estimation of meaning closeness.

Can be used to rank retrieved sentences by meaning closeness to provided utterance.
"""

import itertools as it
import logging
from pathlib import Path
from random import shuffle
from typing import Any, TypeVar
from typing import Any

import joblib
import numpy as np
import numpy.typing as npt
import torch
from sentence_transformers import CrossEncoder
from sklearn.linear_model import LogisticRegressionCV
from torch import nn

from autointent.custom_types import LabelType

@@ -54,15 +58,13 @@ def construct_samples(
return pairs, labels


CrossEncoderType = TypeVar("CrossEncoderType", bound="CrossEncoderWithLogreg")


class CrossEncoderWithLogreg:
class NLITransformer:
r"""
Cross-encoder with logistic regression for binary classification.
Cross-encoder for NLI.

This class uses a SentenceTransformers CrossEncoder model to extract features
and LogisticRegressionCV for classification.
At its heart, this class uses a SentenceTransformers CrossEncoder model to extract features.
It then uses either the model's own classifier or our custom-trained LogisticRegressionCV
(a custom classifier layer in the future) to rank documents by their similarity score to the query.

:ivar cross_encoder: The CrossEncoder model used to extract features.
:ivar batch_size: Batch size for processing text pairs.
@@ -72,10 +74,8 @@ class CrossEncoderWithLogreg:
Examples
--------
Creating and fitting the NLITransformer:
>>> from autointent.modules import CrossEncoderWithLogreg
>>> from sentence_transformers import CrossEncoder
>>> model = CrossEncoder("cross-encoder-model")
>>> scorer = CrossEncoderWithLogreg(model)
>>> from autointent._transformers import NLITransformer
>>> scorer = NLITransformer("cross-encoder-model")
>>> utterances = ["What is your name?", "How old are you?"]
>>> labels = [1, 0]
>>> scorer.fit(utterances, labels)
@@ -87,18 +87,41 @@

Saving and loading the model:
>>> scorer.save("outputs/")
>>> loaded_scorer = CrossEncoderWithLogreg.load("outputs/")
>>> loaded_scorer = NLITransformer.load("outputs/")
"""

def __init__(self, model: CrossEncoder, batch_size: int = 326) -> None:
"""
Initialize the CrossEncoderWithLogreg.

:param model: The CrossEncoder model to use.
def __init__(
self,
model: str,
device: str = "cpu",
train_classifier: bool = False,
batch_size: int = 326,
max_length: int | None = None,
classifier_head: LogisticRegressionCV | None = None,
) -> None:
"""
Initialize the NLITransformer.

:param model: The CrossEncoder model name to use.
:param device: Device to run operations on, e.g., "cpu" or "cuda".
:param train_classifier: Whether to train a custom classifier, defaults to False.
:param batch_size: Batch size for processing text pairs, defaults to 326.
:param max_length: Optional maximum input sequence length for the cross-encoder.
:param classifier_head: Optional pre-fitted classifier head, used mainly when restoring a saved model.
"""
self.cross_encoder = model
self.cross_encoder = CrossEncoder(model, trust_remote_code=True, device=device, max_length=max_length) # type: ignore[arg-type]
self.train_classifier = False
self.batch_size = batch_size
self.max_length = max_length
self._clf = classifier_head

if classifier_head is not None or train_classifier:
self.train_classifier = True
self._logits_list: list[npt.NDArray[Any]] = []
self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook)

def _classifier_hook(self, _module, input_tensor, _output_tensor) -> None: # type: ignore[no-untyped-def] # noqa: ANN001
self._logits_list.append(input_tensor[0].cpu().numpy())
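# Note on the hook above: `register_forward_hook` fires on every forward pass
# of the underlying HF model's final `classifier` linear layer, and
# `input_tensor[0]` holds the pooled pair representation entering that layer.
# Collecting it in `self._logits_list` lets `get_features` harvest raw pair
# embeddings from an ordinary `predict()` call without re-implementing the
# model's forward pass.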

@torch.no_grad()
def get_features(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
@@ -108,20 +131,15 @@ def get_features(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
:param pairs: List of text pairs.
:return: Numpy array of extracted features.
"""
logits_list: list[npt.NDArray[Any]] = []
if not self.train_classifier:
return np.array(self.cross_encoder.predict(pairs, batch_size=self.batch_size, activation_fct=nn.Sigmoid()))

def hook_function(module, input_tensor, output_tensor) -> None: # type: ignore[no-untyped-def] # noqa: ARG001, ANN001
logits_list.append(input_tensor[0].cpu().numpy())
# run the pairs through the model; features are captured by the forward hook
self.cross_encoder.predict(pairs, batch_size=self.batch_size)

handler = self.cross_encoder.model.classifier.register_forward_hook(hook_function)

for i in range(0, len(pairs), self.batch_size):
batch = pairs[i : i + self.batch_size]
self.cross_encoder.predict(batch)

handler.remove()

return np.concatenate(logits_list, axis=0)
res = self._logits_list
self._logits_list = []
return np.concatenate(res, axis=0)

def _fit(self, pairs: list[list[str]], labels: list[LabelType]) -> None:
"""
@@ -139,6 +157,8 @@ def _fit(self, pairs: list[list[str]], labels: list[LabelType]) -> None:

features = self.get_features(pairs)

# TODO: LogisticRegressionCV has class_weight="balanced". Is it better to use it instead of balance_factor in
# construct_samples?
clf = LogisticRegressionCV()
clf.fit(features, labels)

@@ -151,6 +171,9 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
:param utterances: List of utterances (texts).
:param labels: Intent class labels corresponding to the utterances.
"""
if not self.train_classifier:
return # do nothing if the classifier is not to be re-trained

pairs, labels_ = construct_samples(utterances, labels, balancing_factor=1)
self._fit(pairs, labels_) # type: ignore[arg-type]

@@ -161,8 +184,40 @@ def predict(self, pairs: list[list[str]]) -> npt.NDArray[Any]:
:param pairs: List of text pairs to classify.
:return: Numpy array of probabilities.
"""
if self.train_classifier and self._clf is None:
msg = "Classifier is not trained yet"
raise ValueError(msg)

features = self.get_features(pairs)
return self._clf.predict_proba(features)[:, 1] # type: ignore[no-any-return]

if self._clf is not None:
return np.array(self._clf.predict_proba(features)[:, 1])

return features

def rank(
self,
query: str,
query_docs: list[str],
top_k: int | None = None,
) -> list[dict[str, Any]]:
"""
Rank documents by closeness in meaning to the query.

:param query: The reference document.
:param query_docs: List of documents to rank.
:param top_k: How many documents to return.
:return: List of dictionaries with ranked items.
"""
query_doc_pairs = [[query, doc] for doc in query_docs]
scores = self.predict(query_doc_pairs)

if top_k is None:
top_k = len(query_docs)

results = [{"corpus_id": i, "score": scores[i]} for i in range(len(query_docs))]
results.sort(key=lambda x: x["score"], reverse=True)
return results[:top_k]

def save(self, path: str) -> None:
"""
@@ -178,21 +233,13 @@ def save(self, path: str) -> None:
clf_path = dump_dir / "classifier.joblib"
joblib.dump(self._clf, clf_path)

def set_classifier(self, clf: LogisticRegressionCV) -> None:
"""
Set the logistic regression classifier.

:param clf: LogisticRegressionCV instance.
"""
self._clf = clf

@classmethod
def load(cls, path: str) -> "CrossEncoderWithLogreg":
def load(cls, path: str) -> "NLITransformer":
"""
Load the model and classifier from disk.

:param path: Directory path containing the saved model and classifier.
:return: Initialized CrossEncoderWithLogreg instance.
:return: Initialized NLITransformer instance.
"""
dump_dir = Path(path)

@@ -202,9 +249,5 @@ def load(cls, path: str) -> "CrossEncoderWithLogreg":

# Load sentence transformer model
crossencoder_dir = str(dump_dir / "crossencoder")
model = CrossEncoder(crossencoder_dir)

res = cls(model)
res.set_classifier(clf)

return res
return cls(crossencoder_dir, classifier_head=clf)
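For context, a minimal end-to-end sketch of the new wrapper, assembled from the docstring examples above; the checkpoint name "cross-encoder-model" is a placeholder, as in the docstring:

from autointent._transformers import NLITransformer

# "cross-encoder-model" is a placeholder checkpoint name.
scorer = NLITransformer("cross-encoder-model", train_classifier=True)
scorer.fit(["What is your name?", "How old are you?"], [1, 0])

# Rank candidate documents by similarity to the query; top_k limits the output.
ranked = scorer.rank("what is your name?", ["I am 25 years old.", "My name is John."], top_k=1)

# save() dumps the cross-encoder and classifier head; load() restores both.
scorer.save("outputs/")
restored = NLITransformer.load("outputs/")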
27 changes: 17 additions & 10 deletions autointent/context/data_handler/_data_handler.py
@@ -30,17 +30,16 @@ class DataHandler:
"""Data handler class."""

def __init__(
self,
dataset: Dataset,
force_multilabel: bool = False,
random_seed: int = 0,
self, dataset: Dataset, force_multilabel: bool = False, random_seed: int = 0, split_train: bool = True
) -> None:
"""
Initialize the data handler.

:param dataset: Training dataset.
:param force_multilabel: If True, force the dataset to be multilabel.
:param random_seed: Seed for random number generation.
:param split_train: Whether to split the train set; defaults to True, with the resulting
splits used in scoring and threshold search.
"""
set_seed(random_seed)

Expand All @@ -50,7 +49,7 @@ def __init__(

self.n_classes = self.dataset.n_classes

self._split(random_seed)
self._split(random_seed, split_train)

self.regexp_patterns = [
RegexPatterns(
@@ -191,11 +190,11 @@ def dump(self, filepath: str | Path) -> None:
"""
self.dataset.to_json(filepath)

def _split(self, random_seed: int) -> None:
def _split(self, random_seed: int, split_train: bool) -> None:
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
has_test_split = any(split.startswith(Split.TEST) for split in self.dataset)

if Split.TRAIN in self.dataset:
if split_train and Split.TRAIN in self.dataset:
self._split_train(random_seed)

if Split.TEST not in self.dataset:
@@ -252,13 +251,21 @@ def _split_validation_from_test(self, random_seed: int) -> None:
)

def _split_validation_from_train(self, random_seed: int) -> None:
for idx in range(2):
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
if Split.TRAIN in self.dataset:
self.dataset[Split.TRAIN], self.dataset[Split.VALIDATION] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_{idx}",
split=Split.TRAIN,
test_size=0.2,
random_seed=random_seed,
)
else:
for idx in range(2):
self.dataset[f"{Split.TRAIN}_{idx}"], self.dataset[f"{Split.VALIDATION}_{idx}"] = split_dataset(
self.dataset,
split=f"{Split.TRAIN}_{idx}",
test_size=0.2,
random_seed=random_seed,
)

def _split_test(self, test_size: float, random_seed: int) -> None:
self.dataset[f"{Split.TRAIN}_0"], self.dataset[f"{Split.TEST}_0"] = split_dataset(
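A short sketch of the new split_train flag; the Dataset import path and the from_json loader below are assumptions for illustration:

from autointent import Dataset  # assumed import path
from autointent.context.data_handler import DataHandler  # module path per the diff header

dataset = Dataset.from_json("data.json")  # hypothetical loader mirroring dump()'s to_json
# split_train=False keeps the train split intact instead of carving out
# the subsets used for scoring and threshold search.
handler = DataHandler(dataset, split_train=False)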
25 changes: 8 additions & 17 deletions autointent/modules/scoring/_dnnc/dnnc.py
@@ -8,15 +8,13 @@

import numpy as np
import numpy.typing as npt
from sentence_transformers import CrossEncoder

from autointent import Context
from autointent._transformers import NLITransformer
from autointent.context.vector_index_client import VectorIndexClient, get_db_dir
from autointent.custom_types import BaseMetadataDict, LabelType
from autointent.modules.abc import ScoringModule

from .head_training import CrossEncoderWithLogreg

logger = logging.getLogger(__name__)


@@ -81,8 +79,8 @@ class DNNCScorer(ScoringModule):

.. testoutput::

[[-8.90408421 0. ]
[-8.10923195 0. ]]
[[0.00013581 0. ]
[0.00030066 0. ]]

.. testcleanup::

@@ -94,7 +92,7 @@
name = "dnnc"

crossencoder_subdir: str = "crossencoder"
model: CrossEncoder | CrossEncoderWithLogreg
model: NLITransformer
prebuilt_index: bool = False

def __init__(
@@ -192,8 +190,6 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
"""
self.n_classes = len(set(labels))

self.model = CrossEncoder(self.cross_encoder_name, trust_remote_code=True, device=self.device)

vector_index_client = VectorIndexClient(self.device, self.db_dir, embedder_use_cache=self.embedder_use_cache)

if self.prebuilt_index:
@@ -205,10 +201,8 @@
else:
self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)

if self.train_head:
model = CrossEncoderWithLogreg(self.model)
model.fit(utterances, labels)
self.model = model
self.model = NLITransformer(self.cross_encoder_name, train_classifier=self.train_head, device=self.device)
self.model.fit(utterances, labels)

def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
"""
@@ -256,7 +250,7 @@ def _get_cross_encoder_scores(self, utterances: list[str], candidates: list[list
logger.error(msg)
raise ValueError(msg)

flattened_cross_encoder_scores: npt.NDArray[np.float64] = self.model.predict(flattened_text_pairs) # type: ignore[assignment]
flattened_cross_encoder_scores: npt.NDArray[np.float64] = self.model.predict(flattened_text_pairs)
return [
flattened_cross_encoder_scores[i : i + self.k].tolist() # type: ignore[misc]
for i in range(0, len(flattened_cross_encoder_scores), self.k)
@@ -322,10 +316,7 @@ def load(self, path: str) -> None:
self.vector_index = vector_index_client.get_index(self.embedder_name)

crossencoder_dir = str(dump_dir / self.crossencoder_subdir)
if self.train_head:
self.model = CrossEncoderWithLogreg.load(crossencoder_dir)
else:
self.model = CrossEncoder(crossencoder_dir, device=self.device)
self.model = NLITransformer.load(crossencoder_dir)

def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]], list[list[float]]]:
"""
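As a side note, the list comprehension in _get_cross_encoder_scores reshapes the flat cross-encoder scores back into per-utterance lists of length k; a standalone illustration of that regrouping:

k = 3  # candidates retrieved per utterance
flattened_scores = [0.9, 0.1, 0.4, 0.7, 0.2, 0.8]  # 2 utterances x k candidates

# Regroup the flat list into one score list per utterance.
per_utterance = [flattened_scores[i : i + k] for i in range(0, len(flattened_scores), k)]
assert per_utterance == [[0.9, 0.1, 0.4], [0.7, 0.2, 0.8]]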