fix: tiny modification for embed and rank
LongxingTan authored Nov 1, 2024
1 parent 1c36ed1 commit e30fd01
Showing 22 changed files with 223 additions and 83 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/test.yml
@@ -15,12 +15,12 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-13] # add windows-2019 when poetry allows installation with `-f` flag
- python-version: [3.8, 3.9, '3.10']
+ python-version: [3.8, '3.10', '3.12']

steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

@@ -62,15 +62,15 @@ jobs:

steps:
- name: Check out Git repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
- python-version: 3.8
+ python-version: 3.11

- name: Cache pip
- uses: actions/cache@v2
+ uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('docs/requirements_docs.txt') }}
@@ -91,7 +91,7 @@ jobs:
make html --debug --jobs 2 SPHINXOPTS="-W"
- name: Upload built docs
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: docs-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}
path: docs/build/html/
2 changes: 1 addition & 1 deletion README.md
@@ -84,7 +84,7 @@ print(scores.tolist())
```
</details>

- <details><summary> Index building for dense retrieval search </summary>
+ <details><summary> Faiss retrieval search </summary>

```python
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval
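The README code block is truncated in this diff; as a rough sketch of Faiss-based dense retrieval with these classes (the model id, `encode()` signature, and normalization step are assumptions, not taken from this commit):

```python
import faiss
import numpy as np
from retrievals import AutoModelForEmbedding

# Encode a small corpus into dense vectors (model id is a placeholder).
model = AutoModelForEmbedding.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
corpus = ["Paris is the capital of France.", "Faiss is a library for similarity search."]
embeddings = np.asarray(model.encode(corpus), dtype="float32")
faiss.normalize_L2(embeddings)  # normalize so inner product behaves like cosine

# Build a flat inner-product index and query it.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

query = np.asarray(model.encode(["What is the capital of France?"]), dtype="float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, k=2)
print(ids[0], scores[0])  # best-matching corpus indices and scores
```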
4 changes: 4 additions & 0 deletions docs/source/embed.rst
@@ -275,3 +275,7 @@ cosent loss

Sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


4. Embedding serving
----------------------------------------------
19 changes: 17 additions & 2 deletions docs/source/rerank.rst
@@ -54,7 +54,7 @@ Rerank
Ranking score: [5.445939064025879, 3.0762712955474854]
- **LLM reranking**
+ **LLM generative reranking**

.. code-block:: python
@@ -74,6 +74,10 @@ Rerank
2. Fine-tune cross-encoder reranking model
-----------------------------------------------

Prepare the training data as (query, document, label) triples:
`{(query1, document1, label1), (query2, document2, label2), ...}`


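For illustration, one plausible JSON-lines rendering of these triples (the field names are assumptions, not fixed by this commit):

.. code-block:: json

   {"query": "what is deep learning", "document": "Deep learning is a subfield of machine learning.", "label": 1}
   {"query": "what is deep learning", "document": "The Eiffel Tower is in Paris.", "label": 0}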
.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing
:alt: Open In Colab
@@ -125,6 +129,10 @@ Rerank
3. Fine-tune ColBERT reranking model
----------------------------------------

Prepare the training data:
`{}`


.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing
:alt: Open In Colab
@@ -195,9 +203,16 @@ Rerank
trainer.train()
- 4. Fine-tune LLM reranker
+ 4. Fine-tune LLM Generative reranker
-------------------------------------

Prepare the generative reranking data:
`{}`

Prepare the representative reranking data:
`{}`


.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1fzq1iV7-f8hNKFnjMmpVhVxadqPb9IXk?usp=sharing
:alt: Open In Colab
7 changes: 7 additions & 0 deletions examples/0_embedding/README.md
@@ -65,3 +65,10 @@ model = AutoModelForEmbedding.from_pretrained(
document_instruction='',
)
```

## Deployment

**Prerequisites**
```shell
pip install optimum
```
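As one possible next step (not part of this commit), the embedding model can be exported to ONNX with optimum; the model id below is a placeholder:

```python
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder
ort_model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("hello world", return_tensors="pt")
outputs = ort_model(**inputs)  # last_hidden_state, ready for pooling
ort_model.save_pretrained("onnx-embedding/")
```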
File renamed without changes.
1 change: 0 additions & 1 deletion examples/README_zh_CN.md
@@ -1,7 +1,6 @@
# Open-Retrievals Examples

```shell
- # If you are behind the wall
export HF_ENDPOINT=https://hf-mirror.com
```

6 changes: 0 additions & 6 deletions examples/deployment/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion src/retrievals/__init__.py
@@ -10,7 +10,7 @@
from .models.embedding_auto import AutoModelForEmbedding, ListwiseModel, PairwiseModel
from .models.pooling import AutoPooling
from .models.rerank import AutoModelForRanking, ColBERT, LLMRanker
- from .models.retrieval_auto import AutoModelForRetrieval
+ from .models.retrieval_auto import AutoModelForRetrieval, BM25Retrieval, FaissRetrieval
from .trainer.custom_trainer import CustomTrainer
from .trainer.trainer import RerankTrainer, RetrievalTrainer
from .trainer.tuner import AutoTuner
12 changes: 9 additions & 3 deletions src/retrievals/data/collator.py
@@ -13,7 +13,7 @@

class AutoCollator(DataCollatorWithPadding):
"""Choose the collator based on data/task
- TODO: combine pair, triplet, colbert into one
+ TODO: combine pair, triplet, colbert into one collator
"""

def __init__(self):
@@ -30,9 +30,10 @@ def __init__(
tokenizer: PreTrainedTokenizer,
query_max_length: int = 32,
document_max_length: int = 128,
- append_eos_token: bool = False,
query_key: str = 'query',
document_key: str = 'positive',
+ append_eos_token: bool = False,
+ tokenize_args: Optional[Dict] = None,
) -> None:
self.tokenizer = tokenizer
if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
@@ -176,9 +177,10 @@ def __init__(
self,
tokenizer: PreTrainedTokenizer,
max_length: int = 128,
- append_eos_token: bool = False,
query_key: str = 'query',
document_key: str = 'document',
+ append_eos_token: bool = False,
+ tokenize_args: Optional[Dict] = None,
):
self.tokenizer = tokenizer
if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
@@ -235,6 +237,7 @@ def __init__(
query_key: str = 'query',
positive_key: str = 'positive',
negative_key: str = 'negative',
tokenize_args: Optional[Dict] = None,
) -> None:
self.tokenizer = tokenizer
if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
@@ -319,6 +322,7 @@ def __init__(
add_target_token: str = '',
sep_token: str = "\n",
max_length: int = 128,
tokenize_args: Optional[Dict] = None,
pad_to_multiple_of: Optional[int] = 8,
):
self.tokenizer = tokenizer
@@ -333,13 +337,15 @@ def __call__(self, features: List[Dict[str, Any]], return_tensors='pt'):
examples = []

if isinstance(features[0], dict):
"""explode the {(query, positive, negatives)} to pair data"""
for i in range(len(features)):
examples.append((features[i][self.query_key], features[i][self.positive_key]))
for neg in features[i][self.negative_key]:
examples.append((features[i][self.query_key], neg))
else:
examples = features

# TODO: double check the add_target_token; only 'yes' is supported for now?
batch = self.tokenizer(
[self.bos_token + i[0] for i in examples],
[self.sep_token + i[1] + self.sep_token + self.prompt + self.add_target_token for i in examples],
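The "explode" step above flattens one (query, positive, negatives) record into one pair per document; a standalone sketch of the same logic:

```python
features = [{"query": "q1", "positive": "d_pos", "negative": ["d_neg1", "d_neg2"]}]

examples = []
for feature in features:
    examples.append((feature["query"], feature["positive"]))  # the positive pair
    for neg in feature["negative"]:
        examples.append((feature["query"], neg))  # one pair per negative

assert examples == [("q1", "d_pos"), ("q1", "d_neg1"), ("q1", "d_neg2")]
```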
28 changes: 27 additions & 1 deletion src/retrievals/losses/infonce.py
@@ -7,7 +7,7 @@
from typing import Callable, Literal, Optional, Union

import torch
- import torch.distributed.nn
+ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

@@ -24,6 +24,7 @@ def __init__(
criterion: Union[nn.Module, Callable, None] = nn.CrossEntropyLoss(label_smoothing=0.0, reduction='mean'),
temperature: float = 0.05,
use_inbatch_negative: bool = True,
negatives_cross_device: bool = False,
negative_mode: Literal['paired', 'unpaired'] = "unpaired",
**kwargs
):
@@ -35,16 +36,28 @@ def __init__(
self.criterion = criterion
self.temperature = temperature
self.use_inbatch_negative = use_inbatch_negative
self.negatives_cross_device = negatives_cross_device
self.negative_mode = negative_mode
if self.temperature > 0.5:
logger.error('InfoNCE loss use normalized and inner product by default, temperature should be 0.01 ~ 0.1')
if self.negatives_cross_device:
if not dist.is_initialized():
raise ValueError("Cannot do negatives_cross_device without distributed training")
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()

def forward(
self,
query_embeddings: torch.Tensor,
positive_embeddings: torch.Tensor,
negative_embeddings: Optional[torch.Tensor] = None,
):
if self.negatives_cross_device and self.use_inbatch_negative:
query_embeddings = self._dist_gather_tensor(query_embeddings)
positive_embeddings = self._dist_gather_tensor(positive_embeddings)
if negative_embeddings is not None:
negative_embeddings = self._dist_gather_tensor(negative_embeddings)

query_embeddings = F.normalize(query_embeddings, dim=-1)
positive_embeddings = F.normalize(positive_embeddings, dim=-1)
device = query_embeddings.device
@@ -82,3 +95,16 @@ def forward(
target = torch.zeros(query_embeddings.size(0), dtype=torch.long, device=device)

return self.criterion(similarity, target)

def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
if t is None:
return None
t = t.contiguous()

all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)

all_tensors[self.rank] = t
all_tensors = torch.cat(all_tensors, dim=0)

return all_tensors
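Note that `dist.all_gather` returns tensors detached from the autograd graph; re-assigning the local rank's original tensor (`all_tensors[self.rank] = t`) restores the gradient path for this rank's embeddings. A rough usage sketch under `torchrun`, assuming the class is named `InfoNCE` and importable as below:

```python
import torch
import torch.distributed as dist
from retrievals.losses.infonce import InfoNCE  # assumed import path

# e.g. launched with: torchrun --nproc_per_node=2 train.py
dist.init_process_group(backend="nccl")
loss_fn = InfoNCE(temperature=0.05, use_inbatch_negative=True, negatives_cross_device=True)

query = torch.randn(8, 768, device="cuda", requires_grad=True)
positive = torch.randn(8, 768, device="cuda", requires_grad=True)
loss = loss_fn(query, positive)  # in-batch negatives gathered from all ranks
loss.backward()
```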
10 changes: 5 additions & 5 deletions src/retrievals/losses/mrl_loss.py
@@ -36,13 +36,13 @@ def __init__(self, criterion: nn.Module, mrl_nested_dim: List[int]):
def forward(
self,
query_embeddings: torch.Tensor,
- pos_embeddings: torch.Tensor,
- neg_embeddings: Optional[torch.Tensor] = None,
+ positive_embeddings: torch.Tensor,
+ negative_embeddings: Optional[torch.Tensor] = None,
):
query_mrl_embed = self.query_mrl(query_embeddings)
- positive_mrl_embed = self.positive_mrl(pos_embeddings)
- if neg_embeddings is not None:
-     negative_mrl_embed = self.negative_mrl(neg_embeddings)
+ positive_mrl_embed = self.positive_mrl(positive_embeddings)
+ if negative_embeddings is not None:
+     negative_mrl_embed = self.negative_mrl(negative_embeddings)
else:
negative_mrl_embed = [None] * len(positive_mrl_embed)

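The Matryoshka (MRL) pattern above scores the same pairs at several truncated embedding widths; a minimal self-contained sketch of the idea (not the exact `query_mrl`/`positive_mrl` implementation, which this diff does not show):

```python
import torch
import torch.nn.functional as F

def mrl_infonce(query, positive, nested_dims=(64, 128, 256), temperature=0.05):
    """Average an in-batch contrastive loss over several nested embedding sizes."""
    losses = []
    for dim in nested_dims:
        q = F.normalize(query[:, :dim], dim=-1)  # truncate, then renormalize
        p = F.normalize(positive[:, :dim], dim=-1)
        logits = q @ p.T / temperature  # (batch, batch) similarity matrix
        target = torch.arange(q.size(0))  # diagonal entries are the positives
        losses.append(F.cross_entropy(logits, target))
    return torch.stack(losses).mean()
```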
12 changes: 7 additions & 5 deletions src/retrievals/losses/simcse.py
@@ -33,13 +33,15 @@ def __init__(
def forward(
self,
query_embeddings: torch.Tensor,
- pos_embeddings: torch.Tensor,
- neg_embeddings: Optional[torch.Tensor] = None,
+ positive_embeddings: torch.Tensor,
+ negative_embeddings: Optional[torch.Tensor] = None,
):
- similarity = F.cosine_similarity(query_embeddings.unsqueeze(1), pos_embeddings.unsqueeze(0), dim=-1)
+ similarity = F.cosine_similarity(query_embeddings.unsqueeze(1), positive_embeddings.unsqueeze(0), dim=-1)

- if neg_embeddings is not None:
-     neg_similarity = F.cosine_similarity(query_embeddings.unsqueeze(1), neg_embeddings.unsqueeze(0), dim=-1)
+ if negative_embeddings is not None:
+     neg_similarity = F.cosine_similarity(
+         query_embeddings.unsqueeze(1), negative_embeddings.unsqueeze(0), dim=-1
+     )
similarity = torch.cat([similarity, neg_similarity], dim=1)

similarity = similarity / self.temperature
19 changes: 10 additions & 9 deletions src/retrievals/losses/triplet.py
@@ -21,14 +21,14 @@ def __init__(
temperature: float = 0.05,
margin: float = 0.0,
negatives_cross_device: bool = False,
- batch_hard: bool = False,
+ use_inbatch_negative: bool = False,
**kwargs
):
super().__init__()
self.temperature = temperature
self.margin = margin
self.negatives_cross_device = negatives_cross_device
- self.batch_hard = batch_hard
+ self.use_inbatch_negative = use_inbatch_negative
if self.negatives_cross_device:
if not dist.is_initialized():
raise ValueError("Cannot do negatives_cross_device without distributed training")
@@ -38,22 +38,23 @@ def __init__(
def forward(
self,
query_embeddings: torch.Tensor,
- pos_embeddings: torch.Tensor,
- neg_embeddings: torch.Tensor,
+ positive_embeddings: torch.Tensor,
+ negative_embeddings: torch.Tensor,
margin: float = 0.0,
):
if margin:
self.set_margin(margin=margin)

- if self.negatives_cross_device:
-     pos_embeddings = self._dist_gather_tensor(pos_embeddings)
-     neg_embeddings = self._dist_gather_tensor(neg_embeddings)
+ if self.negatives_cross_device and self.use_inbatch_negative:
+     query_embeddings = self._dist_gather_tensor(query_embeddings)
+     positive_embeddings = self._dist_gather_tensor(positive_embeddings)
+     negative_embeddings = self._dist_gather_tensor(negative_embeddings)

- pos_similarity = torch.cosine_similarity(query_embeddings, pos_embeddings, dim=-1)
+ pos_similarity = torch.cosine_similarity(query_embeddings, positive_embeddings, dim=-1)
pos_similarity = pos_similarity / self.temperature
neg_similarity = torch.cosine_similarity(
query_embeddings.unsqueeze(1),
- neg_embeddings.unsqueeze(0),
+ negative_embeddings.unsqueeze(0),
dim=-1,
)
neg_similarity = neg_similarity / self.temperature
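The forward above scales cosine similarities by `temperature` and compares each query's positive against all negatives; the final reduction is truncated from this diff, but a common hinge-style formulation looks like this (a sketch, not necessarily the exact code):

```python
import torch
import torch.nn.functional as F

def triplet_hinge_loss(query, positive, negative, temperature=0.05, margin=0.0):
    pos_sim = F.cosine_similarity(query, positive, dim=-1) / temperature  # (B,)
    neg_sim = F.cosine_similarity(
        query.unsqueeze(1), negative.unsqueeze(0), dim=-1
    ) / temperature  # (B, B)
    # every negative should score at least `margin` below the paired positive
    return F.relu(neg_sim - pos_sim.unsqueeze(1) + margin).mean()
```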
5 changes: 5 additions & 0 deletions src/retrievals/models/base.py
@@ -1,3 +1,5 @@
"""Base model for embedding and reranking"""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
@@ -40,6 +42,9 @@ def save_pretrained(self, path: str, safe_serialization: bool = True):
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

def enable_input_require_grads(self, **kwargs):
self.model.enable_input_require_grads(**kwargs)

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None):
# add new, random embeddings for the new tokens
self.model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
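`enable_input_require_grads` is the usual companion to gradient checkpointing when the base weights are frozen (e.g. LoRA fine-tuning): without it, checkpointed blocks receive inputs with `requires_grad=False` and the backward pass breaks. A hedged usage sketch (method names follow the wrapper above):

```python
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()  # keep gradients flowing into checkpointed blocks
```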
(diff truncated; remaining changed files not shown)
