enable quote-annotations for flake8 and refactor accordingly
mrjleo committed Dec 11, 2024
1 parent 696d80a commit 49934e1
Showing 8 changed files with 84 additions and 57 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -83,3 +83,6 @@ ignore = [
"PLR0913", # too-many-arguments
"PLR2004", # magic-value-comparison
]

[tool.ruff.lint.flake8-type-checking]
quote-annotations = true
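
For context: `quote-annotations` is a setting of Ruff's flake8-type-checking (TC) rules. With it enabled, Ruff may suggest quoting an annotation when doing so lets the corresponding import be moved into an `if TYPE_CHECKING:` block, so the import is only needed by static type checkers. A minimal sketch of the resulting pattern, which the diffs below apply across the package (hypothetical function and names, not taken from this repository):

```python
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Only evaluated by type checkers (and Ruff); never imported at runtime.
    from collections.abc import Sequence
    from pathlib import Path


def encode(texts: "Sequence[str]", cache_dir: "Path | None" = None) -> np.ndarray:
    # The quoted annotations stay plain strings at runtime, so the
    # TYPE_CHECKING-only imports above are not required to run this function.
    return np.zeros(len(texts))
```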
24 changes: 13 additions & 11 deletions src/fast_forward/encoder.py
@@ -1,23 +1,25 @@
""".. include:: docs/encoder.md""" # noqa: D400, D415

import abc
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from pathlib import Path


class Encoder(abc.ABC):
"""Base class for encoders."""

@abc.abstractmethod
def _encode(self, texts: Sequence[str]) -> np.ndarray:
def _encode(self, texts: "Sequence[str]") -> np.ndarray:
pass

def __call__(self, texts: Sequence[str]) -> np.ndarray:
def __call__(self, texts: "Sequence[str]") -> np.ndarray:
"""Encode a list of texts.
:param texts: The texts to encode.
@@ -30,7 +32,7 @@ class TransformerEncoder(Encoder):
"""Uses a pre-trained transformer model for encoding. Returns the pooler output."""

def __init__(
self, model: str | Path, device: str = "cpu", **tokenizer_args: Any
self, model: "str | Path", device: str = "cpu", **tokenizer_args: Any
) -> None:
"""Create a transformer encoder.
@@ -46,7 +48,7 @@ def __init__(
self.device = device
self.tokenizer_args = tokenizer_args

def _encode(self, texts: Sequence[str]) -> np.ndarray:
def _encode(self, texts: "Sequence[str]") -> np.ndarray:
inputs = self.tokenizer(list(texts), return_tensors="pt", **self.tokenizer_args)
inputs.to(self.device)
with torch.no_grad():
@@ -56,15 +58,15 @@ def _encode(self, texts: Sequence[str]) -> np.ndarray:
class LambdaEncoder(Encoder):
"""Encoder adapter class for arbitrary encoding functions."""

def __init__(self, f: Callable[[str], np.ndarray]) -> None:
def __init__(self, f: "Callable[[str], np.ndarray]") -> None:
"""Create a lambda encoder.
:param f: Function to encode a single piece of text.
"""
super().__init__()
self._f = f

def _encode(self, texts: Sequence[str]) -> np.ndarray:
def _encode(self, texts: "Sequence[str]") -> np.ndarray:
return np.array(list(map(self._f, texts)))


@@ -75,7 +77,7 @@ class TCTColBERTQueryEncoder(TransformerEncoder):
https://github.com/castorini/pyserini/blob/310c828211bb3b9528cfd59695184c80825684a2/pyserini/encode/_tct_colbert.py#L72
"""

def _encode(self, texts: Sequence[str]) -> np.ndarray:
def _encode(self, texts: "Sequence[str]") -> np.ndarray:
max_length = 36
inputs = self.tokenizer(
["[CLS] [Q] " + q + "[MASK]" * max_length for q in texts],
@@ -98,7 +100,7 @@ class TCTColBERTDocumentEncoder(TransformerEncoder):
https://github.com/castorini/pyserini/blob/310c828211bb3b9528cfd59695184c80825684a2/pyserini/encode/_tct_colbert.py#L27
"""

def _encode(self, texts: Sequence[str]) -> np.ndarray:
def _encode(self, texts: "Sequence[str]") -> np.ndarray:
max_length = 512
inputs = self.tokenizer(
["[CLS] [D] " + text for text in texts],
23 changes: 13 additions & 10 deletions src/fast_forward/index/disk.py
@@ -1,19 +1,22 @@
import logging
from collections import defaultdict
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import cast
from typing import TYPE_CHECKING, cast

import h5py
import numpy as np
from tqdm import tqdm

import fast_forward
from fast_forward.encoder import Encoder
from fast_forward.index import IDSequence, Index, Mode
from fast_forward.index.memory import InMemoryIndex
from fast_forward.quantizer import Quantizer

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator
    from pathlib import Path

    from fast_forward.encoder import Encoder

LOGGER = logging.getLogger(__name__)


@@ -25,8 +28,8 @@ class OnDiskIndex(Index):

def __init__(
self,
index_file: Path,
query_encoder: Encoder | None = None,
index_file: "Path",
query_encoder: "Encoder | None" = None,
quantizer: Quantizer | None = None,
mode: Mode = Mode.MAXP,
encoder_batch_size: int = 32,
@@ -239,7 +242,7 @@ def _get_doc_ids(self) -> set[str]:
def _get_psg_ids(self) -> set[str]:
return set(self._psg_id_to_idx.keys())

def _get_vectors(self, ids: Iterable[str]) -> tuple[np.ndarray, list[list[int]]]:
def _get_vectors(self, ids: "Iterable[str]") -> tuple[np.ndarray, list[list[int]]]:
idx_pairs = []
with h5py.File(self._index_file, "r") as fp:
for id in ids:
@@ -276,7 +279,7 @@ def _get_vectors(self, ids: Iterable[str]) -> tuple[np.ndarray, list[list[int]]]

def _batch_iter(
self, batch_size: int
) -> Iterator[tuple[np.ndarray, IDSequence, IDSequence]]:
) -> "Iterator[tuple[np.ndarray, IDSequence, IDSequence]]":
with h5py.File(self._index_file, "r") as fp:
num_vectors = cast(int, fp.attrs["num_vectors"])
for i in range(0, num_vectors, batch_size):
@@ -294,8 +297,8 @@ def _batch_iter(
@classmethod
def load(
cls,
index_file: Path,
query_encoder: Encoder | None = None,
index_file: "Path",
query_encoder: "Encoder | None" = None,
mode: Mode = Mode.MAXP,
encoder_batch_size: int = 32,
resize_min_val: int = 2**10,
18 changes: 11 additions & 7 deletions src/fast_forward/index/memory.py
@@ -1,13 +1,17 @@
import logging
from collections import defaultdict
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm

from fast_forward.encoder import Encoder
from fast_forward.index import IDSequence, Index, Mode
from fast_forward.quantizer import Quantizer

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator

    from fast_forward.encoder import Encoder
    from fast_forward.quantizer import Quantizer

LOGGER = logging.getLogger(__name__)

@@ -17,8 +21,8 @@ class InMemoryIndex(Index):

def __init__(
self,
query_encoder: Encoder | None = None,
quantizer: Quantizer | None = None,
query_encoder: "Encoder | None" = None,
quantizer: "Quantizer | None" = None,
mode: Mode = Mode.MAXP,
encoder_batch_size: int = 32,
init_size: int = 2**14,
@@ -145,7 +149,7 @@ def _index_shards(self, idx: int) -> tuple[int, int]:
idx_in_shard = idx % self._alloc_size
return shard_idx, idx_in_shard

def _get_vectors(self, ids: Iterable[str]) -> tuple[np.ndarray, list[list[int]]]:
def _get_vectors(self, ids: "Iterable[str]") -> tuple[np.ndarray, list[list[int]]]:
items_by_shard = defaultdict(list)
for id in ids:
if self.mode in (Mode.MAXP, Mode.AVEP) and id in self._doc_id_to_idx:
@@ -178,7 +182,7 @@ def _get_vectors(self, ids: Iterable[str]) -> tuple[np.ndarray, list[list[int]]]

def _batch_iter(
self, batch_size: int
) -> Iterator[tuple[np.ndarray, IDSequence, IDSequence]]:
) -> "Iterator[tuple[np.ndarray, IDSequence, IDSequence]]":
LOGGER.info("creating ID mappings for this index")
idx_to_doc_id = {
idx: doc_id
28 changes: 15 additions & 13 deletions src/fast_forward/indexer.py
@@ -1,15 +1,17 @@
""".. include:: docs/indexer.md""" # noqa: D400, D415

import logging
from collections.abc import Iterable, Sequence
from typing import TypedDict
from typing import TYPE_CHECKING, TypedDict

import numpy as np
from tqdm import tqdm

from fast_forward.encoder import Encoder
from fast_forward.index import IDSequence, Index
from fast_forward.quantizer import Quantizer
if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence

    from fast_forward.encoder import Encoder
    from fast_forward.index import IDSequence, Index
    from fast_forward.quantizer import Quantizer

LOGGER = logging.getLogger(__name__)

@@ -27,11 +29,11 @@ class Indexer:

def __init__(
self,
index: Index,
encoder: Encoder | None = None,
index: "Index",
encoder: "Encoder | None" = None,
encoder_batch_size: int = 128,
batch_size: int = 2**16,
quantizer: Quantizer | None = None,
quantizer: "Quantizer | None" = None,
quantizer_fit_batches: int = 1,
) -> None:
"""Instantiate an indexer.
@@ -79,8 +81,8 @@ def __init__(
def _index_batch(
self,
vectors: np.ndarray,
doc_ids: IDSequence | None = None,
psg_ids: IDSequence | None = None,
doc_ids: "IDSequence | None" = None,
psg_ids: "IDSequence | None" = None,
) -> None:
"""Add a batch to the index.
@@ -129,7 +131,7 @@ def _index_batch(
del self._buf_doc_ids
del self._buf_psg_ids

def _encode(self, texts: Sequence[str]) -> np.ndarray:
def _encode(self, texts: "Sequence[str]") -> np.ndarray:
"""Encode a list of strings (respecting the encoder batch size).
:param texts: The pieces of text to encode.
@@ -145,7 +147,7 @@ def _encode(self, texts: Sequence[str]) -> np.ndarray:
result.append(self._encoder(batch))
return np.concatenate(result)

def from_dicts(self, data: Iterable[IndexingDict]) -> None:
def from_dicts(self, data: "Iterable[IndexingDict]") -> None:
"""Index data from dictionaries.
:param data: An iterable of the dictionaries.
@@ -163,7 +165,7 @@ def from_dicts(self, data: Iterable[IndexingDict]) -> None:
if len(texts) > 0:
self._index_batch(self._encode(texts), doc_ids=doc_ids, psg_ids=psg_ids)

def from_index(self, index: Index) -> None:
def from_index(self, index: "Index") -> None:
"""Transfer vectors and IDs from another index.
If the source index uses quantized representations, the vectors are
9 changes: 6 additions & 3 deletions src/fast_forward/ranking.py
@@ -2,11 +2,14 @@

import logging
from collections.abc import Iterator, Mapping
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

if TYPE_CHECKING:
    from pathlib import Path

LOGGER = logging.getLogger(__name__)

Run = Mapping[str, Mapping[str, float]]
@@ -335,7 +338,7 @@ def rr_scores(self, k: int = 60) -> "Ranking":

def save(
self,
target: Path,
target: "Path",
) -> None:
"""Save the ranking in a TREC runfile.
@@ -376,7 +379,7 @@ def from_run(
@classmethod
def from_file(
cls,
f: Path,
f: "Path",
queries: Mapping[str, str] | None = None,
dtype: np.dtype = np.dtype(np.float32),
) -> "Ranking":
21 changes: 13 additions & 8 deletions src/fast_forward/util/__init__.py
@@ -1,16 +1,21 @@
""".. include:: ../docs/util.md""" # noqa: D400, D415

from collections.abc import Callable

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from tqdm import tqdm

from fast_forward.index import Index
from fast_forward.ranking import Ranking
if TYPE_CHECKING:
    from collections.abc import Callable

    import pandas as pd

    from fast_forward.index import Index
    from fast_forward.ranking import Ranking


def to_ir_measures(ranking: Ranking) -> pd.DataFrame:
def to_ir_measures(ranking: "Ranking") -> "pd.DataFrame":
"""Return a ranking as a data frame suitable for the ir-measures library.
:param ranking: The input ranking.
@@ -33,10 +38,10 @@ def cos_dist(a: np.ndarray, b: np.ndarray) -> float:


def create_coalesced_index(
source_index: Index,
target_index: Index,
source_index: "Index",
target_index: "Index",
delta: float,
distance_function: Callable[[np.ndarray, np.ndarray], float] = cos_dist,
distance_function: "Callable[[np.ndarray, np.ndarray], float]" = cos_dist,
batch_size: int | None = None,
) -> None:
"""Create a compressed index using sequential coalescing.
(The diff for the eighth changed file did not load and is not shown here.)
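
A note on why the quotes matter: a name imported only under `TYPE_CHECKING` does not exist at runtime, and unquoted annotations are evaluated when a function is defined (in current CPython versions without `from __future__ import annotations`), which would raise a `NameError`. A small sketch (hypothetical function, not from this repository):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path

# Unquoted, the annotation below would be evaluated at definition time and
# raise NameError, because Path is never imported at runtime:
#     def save(target: Path) -> None: ...
# Quoting defers evaluation, so the definition (and the call) succeed:
def save(target: "Path | str") -> None:
    print(f"would save to {target}")

save("out/run.tsv")
```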
