-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from mrjleo/api
Organize API
- Loading branch information
Showing
30 changed files
with
914 additions
and
853 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,11 @@ | ||
""".. include:: docs/main.md""" # noqa: D400, D415 | ||
""".. include:: docs/main.md | ||
.. include:: docs/ranking.md | ||
""" # noqa: D205, D400, D415 | ||
|
||
import importlib.metadata | ||
|
||
# in this specific case, the redundant aliases are recommended by pyright | ||
# ruff: noqa: PLC0414 | ||
from fast_forward.index import Mode as Mode | ||
from fast_forward.index.disk import OnDiskIndex as OnDiskIndex | ||
from fast_forward.index.memory import InMemoryIndex as InMemoryIndex | ||
from fast_forward.indexer import Indexer as Indexer | ||
from fast_forward.quantizer.nanopq import NanoOPQ as NanoOPQ | ||
from fast_forward.quantizer.nanopq import NanoPQ as NanoPQ | ||
from fast_forward.ranking import Ranking as Ranking | ||
from fast_forward import encoder, index, quantizer, util | ||
from fast_forward.ranking import Ranking | ||
|
||
__all__ = ["encoder", "index", "quantizer", "util", "Ranking"] | ||
__version__ = importlib.metadata.version("fast-forward-indexes") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# PyTerrier transformers | ||
|
||
Fast-Forward indexes can seamlessly be integrated into [PyTerrier](https://pyterrier.readthedocs.io/en/latest/) pipelines using the transformers provided in `fast_forward.util.pyterrier`. Specifically, a re-ranking pipeline might look like this, given that `my_index` is a Fast-Forward index of the MS MARCO passage corpus: | ||
|
||
```python | ||
bm25 = pt.BatchRetrieve.from_dataset( | ||
"msmarco_passage", variant="terrier_stemmed", wmodel="BM25" | ||
) | ||
|
||
ff_pl = bm25 % 5000 >> FFScore(my_index) >> FFInterpolate(0.2) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""".. include:: ../docs/encoder.md""" # noqa: D400, D415 | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
|
||
from fast_forward.encoder.base import Encoder | ||
from fast_forward.encoder.transformer import ( | ||
TCTColBERTDocumentEncoder, | ||
TCTColBERTQueryEncoder, | ||
TransformerEncoder, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable, Sequence | ||
|
||
__all__ = [ | ||
"Encoder", | ||
"LambdaEncoder", | ||
"TransformerEncoder", | ||
"TCTColBERTQueryEncoder", | ||
"TCTColBERTDocumentEncoder", | ||
] | ||
|
||
|
||
class LambdaEncoder(Encoder): | ||
"""Encoder adapter class for arbitrary encoding functions.""" | ||
|
||
def __init__(self, f: "Callable[[str], np.ndarray]") -> None: | ||
"""Create a lambda encoder. | ||
:param f: Function to encode a single piece of text. | ||
""" | ||
super().__init__() | ||
self._f = f | ||
|
||
def _encode(self, texts: "Sequence[str]") -> np.ndarray: | ||
return np.array(list(map(self._f, texts))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import abc | ||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Sequence | ||
|
||
import numpy as np | ||
|
||
|
||
class Encoder(abc.ABC): | ||
"""Base class for encoders.""" | ||
|
||
@abc.abstractmethod | ||
def _encode(self, texts: "Sequence[str]") -> "np.ndarray": | ||
pass | ||
|
||
def __call__(self, texts: "Sequence[str]") -> "np.ndarray": | ||
"""Encode a list of texts. | ||
:param texts: The texts to encode. | ||
:return: The resulting vector representations. | ||
""" | ||
return self._encode(texts) |
Oops, something went wrong.