
Commit

Merge pull request #11 from mrjleo/encoders
Encoders
mrjleo authored Dec 19, 2024
2 parents dc91095 + ce1da46 commit 0cfb393
Showing 5 changed files with 305 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -53,10 +53,10 @@ useLibraryCodeForTypes = true
 reportMissingParameterType = true
 
 [tool.ruff]
+extend-exclude = ["tests"]
 line-length = 88
 
 [tool.ruff.lint]
-exclude = ["tests/*"]
 select = [
     "F", # Pyflakes
     "E", # pycodestyle
6 changes: 6 additions & 0 deletions src/fast_forward/encoder/__init__.py
@@ -6,6 +6,9 @@
 
 from fast_forward.encoder.base import Encoder
 from fast_forward.encoder.transformer import (
+    BGEEncoder,
+    ContrieverEncoder,
+    TASBEncoder,
     TCTColBERTDocumentEncoder,
     TCTColBERTQueryEncoder,
     TransformerEncoder,
@@ -20,6 +23,9 @@
     "TransformerEncoder",
     "TCTColBERTQueryEncoder",
     "TCTColBERTDocumentEncoder",
+    "TASBEncoder",
+    "ContrieverEncoder",
+    "BGEEncoder",
 ]
 
 
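For context, a minimal usage sketch of the newly exported encoders. This assumes that Encoder instances are callable on a sequence of strings via the base class and return a NumPy array of embeddings (the calling convention is not shown in this diff); the model names are the defaults visible in transformer.py below.

# Usage sketch -- assumption: Encoder instances are callable on a sequence of
# strings and return a NumPy array of embeddings.
from fast_forward.encoder import BGEEncoder, TCTColBERTQueryEncoder

query_encoder = TCTColBERTQueryEncoder(device="cpu")
query_vectors = query_encoder(["what is dense retrieval?", "how does re-ranking work?"])

doc_encoder = BGEEncoder(device="cpu")  # outputs are L2-normalized by default
doc_vectors = doc_encoder(["Dense retrieval encodes queries and documents into vectors."])

print(query_vectors.shape, doc_vectors.shape)  # e.g. (2, 768) and (1, 768)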
271 changes: 216 additions & 55 deletions src/fast_forward/encoder/transformer.py
@@ -1,97 +1,258 @@
 from typing import TYPE_CHECKING, Any
 
-import numpy as np
 import torch
 from transformers import AutoModel, AutoTokenizer
 
 from fast_forward.encoder.base import Encoder
 
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Mapping, Sequence
     from pathlib import Path
 
+    import numpy as np
+    from transformers import BatchEncoding
+    from transformers.modeling_outputs import BaseModelOutput
+
 
 class TransformerEncoder(Encoder):
-    """Uses a pre-trained transformer model for encoding. Returns the pooler output."""
+    """Uses a pre-trained Transformer model for encoding.
+    The outputs corresponding to the CLS token from the last hidden layer are used.
+    """
 
     def __init__(
-        self, model: "str | Path", device: str = "cpu", **tokenizer_args: Any
+        self,
+        model: "str | Path",
+        device: str = "cpu",
+        model_args: "Mapping[str, Any]" = {},
+        tokenizer_args: "Mapping[str, Any]" = {},
+        tokenizer_call_args: "Mapping[str, Any]" = {
+            "padding": True,
+            "truncation": True,
+        },
+        normalize: bool = False,
     ) -> None:
-        """Create a transformer encoder.
+        """Create a Transformer encoder.
-        :param model: Pre-trained transformer model (name or path).
+        :param model: Pre-trained Transformer model (name or path).
         :param device: PyTorch device.
-        :param **tokenizer_args: Additional tokenizer arguments.
+        :param model_args: Additional arguments for the model.
+        :param tokenizer_args: Additional arguments for the tokenizer.
+        :param tokenizer_call_args: Additional arguments for the tokenizer call.
+        :param normalize: L2-normalize output representations.
         """
         super().__init__()
-        self._model = AutoModel.from_pretrained(model)
+        self._model = AutoModel.from_pretrained(model, **model_args)
         self._model.to(device)
         self._model.eval()
-        self._tokenizer = AutoTokenizer.from_pretrained(model)
+        self._tokenizer = AutoTokenizer.from_pretrained(model, **tokenizer_args)
         self._device = device
-        self._tokenizer_args = tokenizer_args
+        self._tokenizer_call_args = tokenizer_call_args
+        self._normalize = normalize
 
+    def _get_tokenizer_inputs(self, texts: "Sequence[str]") -> list[str]:
+        """Prepare input texts for tokenization.
+        :param texts: The texts to encode.
+        :return: The tokenizer inputs.
+        """
+        return list(texts)
+
+    def _aggregate_model_outputs(
+        self,
+        model_outputs: "BaseModelOutput",
+        model_inputs: "BatchEncoding",  # noqa: ARG002
+    ) -> torch.Tensor:
+        """Aggregate Transformer outputs using the CLS token (last hidden state).
+        Encoders overriding this function may make use of `model_inputs`.
+        :param model_outputs: The Transformer outputs.
+        :param model_inputs: The Transformer inputs (unused).
+        :return: The CLS token representations from the last hidden state.
+        """
+        return model_outputs.last_hidden_state[:, 0]
+
+    def _encode(self, texts: "Sequence[str]") -> "np.ndarray":
+        model_inputs = self._tokenizer(
+            self._get_tokenizer_inputs(texts),
+            return_tensors="pt",
+            **self._tokenizer_call_args,
+        ).to(self._device)
+
-    def _encode(self, texts: "Sequence[str]") -> np.ndarray:
-        inputs = self._tokenizer(
-            list(texts), return_tensors="pt", **self._tokenizer_args
-        )
-        inputs.to(self._device)
         with torch.no_grad():
-            return self._model(**inputs).pooler_output.detach().cpu().numpy()
+            model_outputs = self._model(**model_inputs)
+        result = self._aggregate_model_outputs(model_outputs, model_inputs)
+        if self._normalize:
+            result = torch.nn.functional.normalize(result, p=2, dim=1)
+        return result.cpu().detach().numpy()
 
 
 class TCTColBERTQueryEncoder(TransformerEncoder):
-    """Query encoder for pre-trained TCT-ColBERT models.
+    """Pre-trained TCT-ColBERT query encoder.
     Adapted from Pyserini:
     https://github.com/castorini/pyserini/blob/310c828211bb3b9528cfd59695184c80825684a2/pyserini/encode/_tct_colbert.py#L72
+    Corresponding paper: https://aclanthology.org/2021.repl4nlp-1.17/
     """
 
-    def _encode(self, texts: "Sequence[str]") -> np.ndarray:
-        max_length = 36
-        inputs = self._tokenizer(
-            ["[CLS] [Q] " + q + "[MASK]" * max_length for q in texts],
-            max_length=max_length,
-            truncation=True,
-            add_special_tokens=False,
-            return_tensors="pt",
-            **self._tokenizer_args,
+    def __init__(
+        self,
+        model: "str | Path" = "castorini/tct_colbert-msmarco",
+        device: str = "cpu",
+        max_length: int = 36,
+    ) -> None:
+        """Create a TCT-ColBERT query encoder.
+        :param model: Pre-trained TCT-ColBERT model (name or path).
+        :param device: PyTorch device.
+        :param max_length: Maximum number of tokens per query.
+        """
+        self._max_length = max_length
+        super().__init__(
+            model,
+            device=device,
+            tokenizer_call_args={
+                "max_length": max_length,
+                "truncation": True,
+                "add_special_tokens": False,
+            },
         )
-        inputs.to(self._device)
-        with torch.no_grad():
-            embeddings = self._model(**inputs).last_hidden_state.detach().cpu().numpy()
-            return np.average(embeddings[:, 4:, :], axis=-2)
 
+    def _get_tokenizer_inputs(self, texts: "Sequence[str]") -> list[str]:
+        return ["[CLS] [Q] " + q + "[MASK]" * self._max_length for q in texts]
+
+    def _aggregate_model_outputs(
+        self,
+        model_outputs: "BaseModelOutput",
+        model_inputs: "BatchEncoding",  # noqa: ARG002
+    ) -> torch.Tensor:
+        embeddings = model_outputs.last_hidden_state[:, 4:, :]
+        return torch.mean(embeddings, dim=-2)
+
 
 class TCTColBERTDocumentEncoder(TransformerEncoder):
-    """Document encoder for pre-trained TCT-ColBERT models.
+    """Pre-trained TCT-ColBERT document encoder.
     Adapted from Pyserini:
     https://github.com/castorini/pyserini/blob/310c828211bb3b9528cfd59695184c80825684a2/pyserini/encode/_tct_colbert.py#L27
+    Corresponding paper: https://aclanthology.org/2021.repl4nlp-1.17/
     """
 
-    def _encode(self, texts: "Sequence[str]") -> np.ndarray:
-        max_length = 512
-        inputs = self._tokenizer(
-            ["[CLS] [D] " + text for text in texts],
-            max_length=max_length,
-            padding=True,
-            truncation=True,
-            add_special_tokens=False,
-            return_tensors="pt",
-            **self._tokenizer_args,
+    def __init__(
+        self,
+        model: "str | Path" = "castorini/tct_colbert-msmarco",
+        device: str = "cpu",
+        max_length: int = 512,
+    ) -> None:
+        """Create a TCT-ColBERT document encoder.
+        :param model: Pre-trained TCT-ColBERT model (name or path).
+        :param device: PyTorch device.
+        :param max_length: Maximum number of tokens per document.
+        """
+        self._max_length = max_length
+        super().__init__(
+            model,
+            device=device,
+            tokenizer_call_args={
+                "max_length": max_length,
+                "padding": True,
+                "truncation": True,
+                "add_special_tokens": False,
+            },
         )
-        inputs.to(self._device)
-        with torch.no_grad():
-            outputs = self._model(**inputs)
-            token_embeddings = outputs["last_hidden_state"][:, 4:, :]
-            input_mask_expanded = (
-                inputs.attention_mask[:, 4:]
-                .unsqueeze(-1)
-                .expand(token_embeddings.size())
-                .float()
-            )
-            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
-            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-            embeddings = sum_embeddings / sum_mask
-        return embeddings.detach().cpu().numpy()
+
+    def _get_tokenizer_inputs(self, texts: "Sequence[str]") -> list[str]:
+        return ["[CLS] [D] " + d for d in texts]
+
+    def _aggregate_model_outputs(
+        self,
+        model_outputs: "BaseModelOutput",
+        model_inputs: "BatchEncoding",
+    ) -> torch.Tensor:
+        token_embeddings = model_outputs.last_hidden_state[:, 4:, :]
+        input_mask_expanded = (
+            model_inputs.attention_mask[:, 4:]
+            .unsqueeze(-1)
+            .expand(token_embeddings.size())
+            .float()
+        )
+        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        return sum_embeddings / sum_mask
+
+
+class TASBEncoder(TransformerEncoder):
+    """Pre-trained TAS-B (topic-aware sampling) encoder.
+    Corresponding paper: https://dl.acm.org/doi/10.1145/3404835.3462891
+    """
+
+    def __init__(
+        self,
+        model: "str | Path" = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco",
+        device: str = "cpu",
+    ) -> None:
+        """Create a TAS-B encoder.
+        :param model: Pre-trained TAS-B model (name or path).
+        :param device: PyTorch device.
+        """
+        # TAS-B uses CLS-pooling (TransformerEncoder default)
+        super().__init__(model, device=device)
+
+
+class ContrieverEncoder(TransformerEncoder):
+    """Pre-trained Contriever encoder.
+    Adapted from: https://huggingface.co/facebook/contriever
+    Corresponding paper: https://openreview.net/forum?id=jKN1pXi7b0
+    """
+
+    def __init__(
+        self,
+        model: "str | Path" = "facebook/contriever",
+        device: str = "cpu",
+    ) -> None:
+        """Create a Contriever encoder.
+        :param model: Pre-trained Contriever model (name or path).
+        :param device: PyTorch device.
+        """
+        super().__init__(model, device=device)
+
+    def _aggregate_model_outputs(
+        self,
+        model_outputs: "BaseModelOutput",
+        model_inputs: "BatchEncoding",
+    ) -> torch.Tensor:
+        token_embeddings = model_outputs[0].masked_fill(
+            ~model_inputs.attention_mask[..., None].bool(), 0.0
+        )
+        return (
+            token_embeddings.sum(dim=1)
+            / model_inputs.attention_mask.sum(dim=1)[..., None]
+        )
+
+
+class BGEEncoder(TransformerEncoder):
+    """Pre-trained BGE encoder.
+    Corresponding paper: https://dl.acm.org/doi/10.1145/3626772.3657878
+    """
+
+    def __init__(
+        self,
+        model: "str | Path" = "BAAI/bge-base-en-v1.5",
+        device: str = "cpu",
+    ) -> None:
+        """Create a BGE encoder.
+        :param model: Pre-trained BGE model (name or path).
+        :param device: PyTorch device.
+        """
+        super().__init__(model, device=device, normalize=True)
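The refactoring above turns TransformerEncoder into an extension point: subclasses override _get_tokenizer_inputs and/or _aggregate_model_outputs instead of re-implementing _encode. A sketch of a custom encoder built on these hooks follows; the mean-pooling strategy, the model name, and the assumption that encoders are callable on a sequence of strings are illustrative assumptions, not part of this commit.

# Sketch of a custom encoder using the new hooks. Assumptions: Encoder instances
# are callable on a sequence of strings, and the model name is only an example.
from fast_forward.encoder import TransformerEncoder


class MeanPoolingEncoder(TransformerEncoder):
    """Averages all token representations, weighted by the attention mask."""

    def _aggregate_model_outputs(self, model_outputs, model_inputs):
        # Mask out padding tokens, then average the remaining token embeddings.
        mask = model_inputs.attention_mask.unsqueeze(-1).float()
        summed = (model_outputs.last_hidden_state * mask).sum(dim=1)
        return summed / mask.sum(dim=1).clamp(min=1e-9)


encoder = MeanPoolingEncoder("sentence-transformers/all-MiniLM-L6-v2", normalize=True)
vectors = encoder(["an example passage about dense retrieval"])

Of the built-in encoders in this commit, TASBEncoder relies on the default CLS aggregation, while the TCT-ColBERT encoders and ContrieverEncoder override the same two hooks.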
14 changes: 14 additions & 0 deletions tests/_constants.py



