fix: embed and rerank benchmark
LongxingTan authored Jun 27, 2024
1 parent 4a66ef3 commit 26ad6e9
Showing 19 changed files with 280 additions and 182 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -27,13 +27,22 @@

**[Documentation](https://open-retrievals.readthedocs.io)** | **[中文](https://github.com/LongxingTan/open-retrievals/blob/master/README_zh-CN.md)** | **[日本語](https://github.com/LongxingTan/open-retrievals/blob/master/README_ja-JP.md)**

![structure](./docs/source/_static/structure.png)

**Open-retrievals** simplifies text embedding, retrieval, reranking, and RAG with PyTorch and Transformers. It is a user-friendly framework designed for information retrieval and LLM generation.
- Embeddings, retrieval, and reranking all in one: `AutoModelForEmbedding` (see the usage sketch below the benchmark table)
- Contrastive-learning and LLM-enhanced embeddings, with point-wise, pairwise, and listwise fine-tuning
- Cross-encoder, ColBERT, and LLM rerankers
- Fast RAG, easily integrated with Langchain and LlamaIndex

![structure](./docs/source/_static/structure.png)
| Exp | Model | Original | Finetune | Demo |
|----------------------------|-------------------------|----------|-----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| embed pairwise finetune | bge-base-zh-v1.5 | 0.657 | **0.701** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
| embed llm finetune (LoRA) | Qwen2-1.5B-Instruct | 0.541 | **0.690** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
| rerank cross encoder | bge-reranker-base | 0.666 | **0.691** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing) |
| rerank colbert | chinese-roberta-wwm-ext | 0.643 | **0.683** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |

* The metric is MAP on the [t2-reranking data](https://huggingface.co/datasets/C-MTEB/T2Reranking). The original scores of the LLM and ColBERT models are zero-shot.
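As referenced above, here is a minimal usage sketch. It assumes the package exposes `AutoModelForEmbedding.from_pretrained` and `encode` as suggested by the examples in this commit; exact signatures may differ.

```python
# Minimal embedding sketch (assumed API; check the documentation for exact signatures).
from retrievals import AutoModelForEmbedding

# pooling_method should match the checkpoint, e.g. "cls" for bge-base-zh-v1.5.
model = AutoModelForEmbedding.from_pretrained("BAAI/bge-base-zh-v1.5", pooling_method="cls")
embeddings = model.encode(["how to fine-tune an embedding model?", "open-retrievals quick start"])
print(embeddings.shape)  # (2, hidden_size)
```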


## Installation
4 changes: 2 additions & 2 deletions README_ja-JP.md
@@ -27,13 +27,13 @@

**[Documentation](https://open-retrievals.readthedocs.io)** | **[English](https://github.com/LongxingTan/open-retrievals/blob/master/README.md)** | **[Chinese](https://github.com/LongxingTan/open-retrievals/blob/master/README_zh-CN.md)**

![structure](./docs/source/_static/structure.png)

**Open-Retrievals** is an easy-to-use Python framework, built on PyTorch and Transformers, for obtaining SOTA text embeddings, oriented toward information retrieval and LLM retrieval-augmented generation.
- `AutoModelForEmbedding` unifies embedding, retrieval, and reranking
- Contrastive-learning embeddings and LLM embeddings
- Fast RAG demo

![structure](./docs/source/_static/structure.png)


## Installation

11 changes: 10 additions & 1 deletion README_zh-CN.md
@@ -27,12 +27,21 @@

**[Chinese wiki](https://github.com/LongxingTan/open-retrievals/wiki)** | **[English documentation](https://open-retrievals.readthedocs.io)** | **[Release Notes](https://open-retrievals.readthedocs.io/en/latest/CHANGELOG.html)**

![structure](./docs/source/_static/structure.png)

**Open-Retrievals** helps developers conveniently apply text embeddings in information retrieval, large language models, and related fields, and quickly build retrieval, reranking, and RAG applications.
- `AutoModelForEmbedding` unifies embedding, retrieval, and reranking
- Multiple fine-tuning approaches for embedding and reranking models: contrastive learning, LLM-based, point-wise, pairwise, and listwise
- A customizable RAG framework; fine-tuned models can also be used conveniently in Langchain and LlamaIndex

![structure](./docs/source/_static/structure.png)
| Exp | Model | Original | Finetune | Demo |
|----------------------------|-------------------------|----------|-----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| embed pairwise finetune | bge-base-zh-v1.5 | 0.657 | **0.701** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
| embed LLM finetune (LoRA) | Qwen2-1.5B-Instruct | 0.541 | **0.690** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
| rerank cross encoder | bge-reranker-base | 0.666 | **0.691** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing) |
| rerank ColBERT | chinese-roberta-wwm-ext | 0.643 | **0.683** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |

* The metric is MAP on the [t2-reranking data](https://huggingface.co/datasets/C-MTEB/T2Reranking). The original scores of the LLM and ColBERT models are zero-shot.


## Installation
45 changes: 21 additions & 24 deletions examples/README.md
@@ -9,16 +9,6 @@
- [rerank-llm finetune](../reference/rerank_llm_finetune.py)
- [RAG with Langchain](./rag_langchain_demo.py)

| Exp | Model | Original | Finetune | Colab |
|----------------------------|-------------------------|----------|-----------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| embed pairwise finetune | bge-base-zh-v1.5 | 0.657 | **0.701** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
| embed llm finetune (LoRA) | Qwen2-1.5B-Instruct | 0.554 | **-** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
| rerank cross encoder | bge-reranker-base | 0.666 | **0.691** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing) |
| rerank colbert (zero shot) | chinese-roberta-wwm-ext | 0.643 | **-** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |
| rerank llm finetune (LoRA) | Qwen2-1.5B-Instruct | | **-** | |

* The metrics are evaluated by MAP on the t2-ranking data


## Retrieval

@@ -53,16 +43,16 @@ torchrun --nproc_per_node 1 \
--train_group_size 2 \
--logging_steps 100 \
--temperature 0.02 \
--use_inbatch_neg false
--use_inbatch_negative false
```

**Pairwise LLM embedding finetune**
- add query_instruction
- "Given a query and a relevant document, retrieve the document that are pertinent to the query\nQuery: "
- use the appropriate pooling_method
- last
- maybe reduce the batch_size due to large model size
- set use_lora to True if you want to use lora
- `last`
- maybe we need to reduce the batch_size due to large model size
- set `use_lora` to True if you want to use lora

```shell
MODEL_NAME="intfloat/e5-mistral-7b-instruct"
@@ -74,25 +64,26 @@ torchrun --nproc_per_node 1 \
--output_dir $OUTPUT_DIR \
--overwrite_output_dir \
--model_name_or_path $MODEL_NAME \
--pooling_method last \
--do_train \
--train_data $TRAIN_DATA \
--positive_key positive \
--negative_key negative \
--use_lora True \
--query_instruction "Given a query and a relevant document, retrieve the document that are pertinent to the query\nQuery: " \
--document_instruction '# Document: ' \
--learning_rate 3e-5 \
--query_instruction "Query: " \
--document_instruction "" \
--learning_rate 5e-5 \
--bf16 \
--num_train_epochs 5 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--dataloader_drop_last True \
--query_max_length 128 \
--query_max_length 256 \
--document_max_length 256 \
--train_group_size 2 \
--logging_steps 100 \
--temperature 0.02 \
--use_inbatch_neg false
--use_inbatch_negative false
```


@@ -144,19 +135,20 @@ torchrun --nproc_per_node 1 \
--positive_key positive \
--negative_key negative \
--learning_rate 1e-5 \
--fp16 \
--num_train_epochs 3 \
--bf16 \
--num_train_epochs 5 \
--per_device_train_batch_size 8 \
--dataloader_drop_last True \
--max_length 512 \
--train_group_size 8 \
--train_group_size 2 \
--unfold_each_positive false \
--save_total_limit 2 \
--logging_steps 100
--logging_steps 100 \
--use_inbatch_negative false
```

**LLM reranking**
- AutoModelForRanking.from_pretrained(model_name_or_path, causal_lm = True)
- `AutoModelForRanking.from_pretrained(model_name_or_path, causal_lm=True)`
- Prompt: "Given a query with a relevant body, determine whether the document is pertinent to the query by providing a prediction of either 'Yes' or 'No'."

```shell
@@ -190,3 +182,8 @@ torchrun --nproc_per_node 1 \
--save_total_limit 2 \
--bf16
```
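
For inference with an LLM reranker, here is a hedged sketch: only the `causal_lm=True` flag comes from the bullet above; the model name, method names, and the way the Yes/No prompt is wired in are assumptions.

```python
# Sketch of LLM-reranker scoring (assumed API; verify names against the documentation).
from retrievals import AutoModelForRanking

model = AutoModelForRanking.from_pretrained(
    "Qwen/Qwen2-1.5B-Instruct",  # hypothetical base or fine-tuned checkpoint
    causal_lm=True,              # from the bullet above: score via a causal LM head
)
scores = model.compute_score([
    ("what is dense retrieval?", "Dense retrieval encodes queries and documents into vectors."),
])
print(scores)
```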


## Common questions
- If grad_norm stays at zero during training, try switching between fp16 and bf16
- If the fine-tuned embedding model performs worse at inference time, check that pooling_method matches the one used for training
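
As a hedged illustration of the pooling_method point above (the path and argument names are assumptions):

```python
# Sketch: the pooling method at inference must match the one used during fine-tuning.
from retrievals import AutoModelForEmbedding

model = AutoModelForEmbedding.from_pretrained(
    "./embed-llm-output",   # hypothetical output_dir from the LLM embedding command above
    pooling_method="last",  # matches --pooling_method last; a mismatch silently degrades retrieval
)
```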
9 changes: 6 additions & 3 deletions src/retrievals/data/dataset.py
@@ -72,7 +72,11 @@ def __init__(
self.samples = self.generate_unfold_samples(dataset)
else:
self.dataset = dataset
logger.info("Generate total {} retrieval data.".format(len(self.dataset)))
logger.info(
"Generate total {} retrieval data. Query instruction: {}, Document instruction: {}".format(
len(self.dataset), self.query_instruction, self.document_instruction
)
)

def __len__(self) -> int:
return len(self.dataset)
@@ -82,7 +86,6 @@ def __getitem__(self, item: int) -> Union[Dict[str, str], List[BatchEncoding]]:
return self.samples[item]

data = self.dataset[item]

query = self.query_instruction + data[self.query_key]

if isinstance(data[self.positive_key], (list, tuple)):
@@ -128,7 +131,7 @@ def generate_unfold_samples(self, dataset):
return samples

def dynamic_sample(self, batch_size: int, missing_list=None, wrong_dict=None, max_wrong: int = 16):
logger.info('\nDynamic Shuffle Sample...')
logger.info('Dynamic Shuffle Sample')
return


2 changes: 1 addition & 1 deletion src/retrievals/losses/colbert_loss.py
@@ -8,7 +8,7 @@ class ColbertLoss(nn.Module):
def __init__(
self,
criterion: Union[nn.Module, Callable] = nn.CrossEntropyLoss(reduction='mean'),
temperature: float = 0.05,
temperature: float = 0.02,
use_inbatch_negative: bool = True,
):
super(ColbertLoss, self).__init__()
6 changes: 5 additions & 1 deletion src/retrievals/losses/infonce.py
@@ -63,15 +63,19 @@ def forward(
return loss
else:
negative_embeddings = F.normalize(negative_embeddings, dim=-1)

if self.use_inbatch_negative:
logits = torch.cat([positive_embeddings, negative_embeddings], dim=0)
similarity = query_embeddings @ logits.transpose(-2, -1)
similarity = similarity / self.temperature
similarity = similarity.view(query_embeddings.size(0), -1)
target = torch.arange(query_embeddings.size(0), dtype=torch.long, device=device)
else:
# -> [batch_size, embedding_size, num_negative_samples]
negative_embeddings = negative_embeddings.view(query_embeddings.size(0), -1, query_embeddings.size(1))
negative_embeddings = negative_embeddings.permute(0, 2, 1)
similarity = query_embeddings.unsqueeze(1) @ positive_embeddings.unsqueeze(2)
negative_similarity = query_embeddings.unsqueeze(1) @ negative_embeddings.unsqueeze(2)
negative_similarity = query_embeddings.unsqueeze(1) @ negative_embeddings
similarity = torch.cat([similarity.squeeze(1), negative_similarity.squeeze(1)], dim=1)
similarity = similarity / self.temperature
target = torch.zeros(query_embeddings.size(0), dtype=torch.long, device=device)
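To make the shape fix in this hunk concrete, here is a small standalone sketch with dummy tensors (not library code): the corrected line multiplies `[batch, 1, embed]` by `[batch, embed, num_neg]` to get one similarity per explicit negative.

```python
import torch

batch_size, embed_dim, num_neg = 4, 8, 3
query = torch.randn(batch_size, embed_dim)
negatives = torch.randn(batch_size * num_neg, embed_dim)

# Reshape to [batch_size, embed_dim, num_neg], as in the hunk above.
negatives = negatives.view(batch_size, num_neg, embed_dim).permute(0, 2, 1)

# [batch, 1, embed] @ [batch, embed, num_neg] -> [batch, 1, num_neg]
negative_similarity = query.unsqueeze(1) @ negatives
print(negative_similarity.squeeze(1).shape)  # torch.Size([4, 3])
```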
2 changes: 1 addition & 1 deletion src/retrievals/losses/pair_kl.py
@@ -17,7 +17,7 @@ def __init__(
def forward(
self, query_embeddings: torch.Tensor, positive_embeddings: torch.Tensor, scores: torch.Tensor, **kwargs
):
similarity = torch.einsum('bn, bn -> b', query_embeddings, positive_embeddings)
similarity = torch.einsum('bn,bn->b', query_embeddings, positive_embeddings)
similarity = similarity / self.temperature
similarity = torch.log_softmax(similarity, dim=-1)
target = torch.softmax(scores / self.temperature, dim=-1)
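As a quick standalone check (not library code), the einsum in this hunk is simply a per-row dot product:

```python
import torch

# 'bn,bn->b' computes one dot product per row: the score of each (query, positive) pair.
q = torch.randn(4, 8)
p = torch.randn(4, 8)
assert torch.allclose(torch.einsum('bn,bn->b', q, p), (q * p).sum(dim=-1))
```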
2 changes: 0 additions & 2 deletions src/retrievals/losses/token_loss.py
@@ -1,10 +1,8 @@
import logging
from typing import Callable, Literal, Optional, Union

import torch
import torch.distributed.nn
import torch.nn as nn
import torch.nn.functional as F


class TokenLoss(nn.Module):
42 changes: 37 additions & 5 deletions src/retrievals/models/base.py
@@ -1,14 +1,18 @@
import logging
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)


class Base(ABC, torch.nn.Module):
def __init__(
@@ -28,11 +32,6 @@ def forward(self, *args, **kwargs):
"""Pytorch forward method."""
raise NotImplementedError

@abstractmethod
def encode(self, *args, **kwargs):
"""Encode documents."""
pass

def _encode_from_loader(
self,
loader: DataLoader,
@@ -84,5 +83,38 @@ def preprocess(self, batch_sentence_pair, query_max_length, document_max_length)
"doc_attention_mask": document_batch_tokens_on_device['attention_mask'],
}

def save_pretrained(self, path: str, safe_serialization: bool = True):
"""
Save the model and tokenizer to the given path.
"""
logger.info("Save model to {}".format(path))
state_dict = self.model.state_dict()
state_dict = type(state_dict)({k: v.clone().cpu() for k, v in state_dict.items()})
self.model.save_pretrained(path, state_dict=state_dict, safe_serialization=safe_serialization)
self.tokenizer.save_pretrained(path)

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

def push_to_hub(self, hub_model_id: str, private: bool = True, **kwargs):
"""push model to hub
:param hub_model_id: str, hub model id.
:param private: bool, whether push to private repo. Default True.
:param kwargs: other kwargs for `push_to_hub` method.
"""
self.tokenizer.push_to_hub(hub_model_id, private=private, **kwargs)
self.backbone.push_to_hub(hub_model_id, private=private, **kwargs)

def _dist_gather_tensor(self, tensor: Optional[torch.Tensor]):
if tensor is None:
return None
tensor = tensor.contiguous()

all_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
dist.all_gather(all_tensors, tensor)

all_tensors[self.process_rank] = tensor
all_tensors = torch.cat(all_tensors, dim=0)

return all_tensors
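
A hedged usage sketch for the persistence helpers added above; the concrete subclass, checkpoint name, and output paths are assumptions, not part of this diff.

```python
# Hypothetical usage of the save/push helpers added to `Base` above.
from retrievals import AutoModelForEmbedding  # assumed concrete subclass of Base

model = AutoModelForEmbedding.from_pretrained("BAAI/bge-base-zh-v1.5", pooling_method="cls")
model.save_pretrained("./my-embedding-model", safe_serialization=True)

# Pushing uploads both the tokenizer and the backbone; requires a configured Hugging Face token.
# model.push_to_hub("your-username/my-embedding-model", private=True)
```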
