feat: build hard negative for retrieval
LongxingTan authored Sep 23, 2024
1 parent 13d30f8 commit 4738bb3
Showing 20 changed files with 296 additions and 148 deletions.
15 changes: 2 additions & 13 deletions README.md
@@ -35,7 +35,7 @@

![structure](./docs/source/_static/structure.png)

**Open-retrievals** unifies text embedding, retrieval, reranking and RAG. It's easy, flexible and scalable.
**Open-retrievals** unifies text embedding, retrieval, reranking and RAG. It's easy, flexible and scalable to fine-tune the model.
- Embedding fine-tuned through point-wise, pairwise, listwise, contrastive learning and LLM.
- Reranking fine-tuned with Cross-Encoder, ColBERT and LLM.
- Easily build enhanced modular RAG, integrated with Transformers, Langchain and LlamaIndex.
@@ -54,23 +54,12 @@

## Installation

**Prerequisites**
```shell
pip install transformers
pip install faiss-cpu  # only needed for faiss retrieval
pip install peft  # only needed for LoRA training
```

**With pip**
```shell
pip install transformers
pip install open-retrievals
```

**With source code**
```shell
python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
```


## Quick-start

25 changes: 7 additions & 18 deletions README_zh-CN.md
@@ -34,9 +34,9 @@

![structure](./docs/source/_static/structure.png)

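**Open-Retrievals** supports unified calling or fine-tuning of text embedding, retrieval and reranking models, making information retrieval and RAG applications more convenient
- Supports the full range of embedding fine-tuning: contrastive learning, LLM, point-wise, pairwise, listwise
- Supports the full range of reranking fine-tuning: cross-encoder, ColBERT, LLM
**Open-Retrievals** unifies calling and fine-tuning of text embedding, retrieval and reranking models, making information retrieval and RAG applications more convenient
- Supports text embedding fine-tuning: contrastive learning, LLM, point-wise, pairwise, listwise
- Supports reranking fine-tuning: cross-encoder, ColBERT, LLM
- Supports customized, modular RAG, with convenient use of fine-tuned models in Transformers, Langchain and LlamaIndex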

| Experiment | Model | Original score | Fine-tuned score | Demo code |
@@ -54,23 +54,12 @@

## Installation

**Basic**
```shell
pip install transformers
pip install faiss  # if needed, for retrieval
pip install peft  # if needed, for LoRA training
```

**Install with pip**
```shell
pip install transformers
pip install open-retrievals
```

**Install from source**
```shell
python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
```


## Quick start

@@ -306,7 +295,7 @@ trainer.train()

</details>

<details><summary> Fine-tune Cross-encoder reranking model </summary>
<details><summary> Fine-tune Cross-encoder reranking </summary>

```python
import os
@@ -356,7 +345,7 @@ trainer.train()

</details>

<details><summary> Fine-tune ColBERT reranking model </summary>
<details><summary> Fine-tune ColBERT reranking </summary>

```python
import os
@@ -422,7 +411,7 @@ trainer.train()

</details>

<details><summary> Fine-tune LLM reranking model </summary>
<details><summary> Fine-tune LLM reranking </summary>

```python
import os
32 changes: 25 additions & 7 deletions docs/source/embed.rst
@@ -76,6 +76,8 @@ Prepare data
Pair wise
~~~~~~~~~~~~~

If the labels of the positive and negative examples are noisy, a direct point-wise cross-entropy loss may not be the best choice. A pairwise loss, which only compares examples relative to each other, or a hinge loss with a margin may work better.
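
To make the contrast concrete, here is a minimal sketch (not the library's implementation) of a point-wise binary cross-entropy next to a pairwise hinge loss with a margin. The function names and the margin value are illustrative assumptions.

```python
import math


def pointwise_bce(score: float, label: int) -> float:
    """Binary cross-entropy on a single (query, doc) score; it depends on the absolute label."""
    prob = 1.0 / (1.0 + math.exp(-score))
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


def pairwise_hinge(pos_score: float, neg_score: float, margin: float = 1.0) -> float:
    """Hinge loss on a (positive, negative) pair; only the score gap matters."""
    return max(0.0, margin - (pos_score - neg_score))


# Once the positive outscores the negative by at least the margin,
# the pair contributes no loss, whatever the absolute scores are.
print(pairwise_hinge(pos_score=2.5, neg_score=1.0))  # gap of 1.5 exceeds the margin
print(pairwise_hinge(pos_score=1.2, neg_score=1.0))  # gap of 0.2 is still penalized
print(pointwise_bce(score=2.5, label=1))
```

Because the pairwise loss only looks at the gap between the two scores, a mislabeled example affects training through its ordering against other examples instead of through an absolute target.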

.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing
:alt: Open In Colab
@@ -188,9 +190,9 @@ Pair wise
Point wise
~~~~~~~~~~~~~~~~~~

If the labels of the positive and negative examples are noisy, a direct point-wise cross-entropy loss may not be the best choice. A pairwise loss, which only compares examples relative to each other, or a hinge loss with a margin may work better.
We can also train point-wise, similar to using `tfidf` in information retrieval.

arcface
**arcface**

- layer wise learning rate
- batch size is important
@@ -202,7 +204,6 @@ List wise
~~~~~~~~~~~~~~~~~~



3. Training skills to enhance the performance
----------------------------------------------

@@ -225,14 +226,31 @@ tuning the important parameters:


Hard negative mining
~~~~~~~~~~~~~~~~~~~~~~~~
offline hard mining
~~~~~~~~~~~~~~~~~~~~~~~~~

- offline hard mining or online hard mining

If we only have query and positive pairs, we can use them to generate more negative samples and improve retrieval performance.

online hard mining
The `input_file` used to generate hard negatives contains `(query, positive)` or `(query, positive, negative)` records.
The `candidate_pool` corpus is a jsonl file of `{text}` records.
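
As a rough illustration of these two formats, the sketch below writes a toy `input_file` and `candidate_pool` as jsonl. The `positive` and `negative` keys match the `--positive_key` and `--negative_key` values used in this section; the `query` key, storing positives and negatives as lists, and the file names are assumptions made for the example.

```python
import json
import os
import tempfile

# Toy records. "query" as the key name and list-valued positives/negatives are assumptions.
train_records = [
    {"query": "what is faiss", "positive": ["Faiss is a library for similarity search."]},
    {
        "query": "what is lora",
        "positive": ["LoRA is a parameter-efficient fine-tuning method."],
        "negative": ["Faiss is a library for similarity search."],
    },
]
# One {text} record per line for the candidate pool.
corpus_records = [
    {"text": "Faiss is a library for similarity search."},
    {"text": "LoRA is a parameter-efficient fine-tuning method."},
    {"text": "ColBERT is a late-interaction retrieval model."},
]


def write_jsonl(path, records):
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path):
    """Read a jsonl file back into a list of dicts."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


workdir = tempfile.mkdtemp()
input_file = os.path.join(workdir, "train.jsonl")  # hypothetical file name
candidate_pool = os.path.join(workdir, "corpus.jsonl")  # hypothetical file name
write_jsonl(input_file, train_records)
write_jsonl(candidate_pool, corpus_records)

loaded_train = read_jsonl(input_file)
loaded_corpus = read_jsonl(candidate_pool)
print(len(loaded_train), len(loaded_corpus))
```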


.. code-block:: shell

    python -m retrievals.pipelines.build_hn \
        --model_name_or_path BAAI/bge-base-en-v1.5 \
        --input_file /t2_ranking.jsonl \
        --output_file /t2_ranking_hn.jsonl \
        --positive_key positive \
        --negative_key negative \
        --range_for_sampling 2-200 \
        --negative_number 15

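
The idea behind offline hard negative mining can be sketched as follows. This is a simplified illustration, not the `build_hn` implementation: it assumes that `--range_for_sampling` selects a window of ranks in the similarity-sorted candidate list and that `--negative_number` caps how many negatives are sampled from that window. The function name, the toy vectors and the dot-product scoring are hypothetical.

```python
import random
from typing import Dict, List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def mine_hard_negatives(
    query_vec: Sequence[float],
    positives: List[str],
    corpus_vecs: Dict[str, Sequence[float]],
    sample_range: Tuple[int, int] = (2, 200),
    negative_number: int = 15,
    seed: int = 0,
) -> List[str]:
    # Rank every candidate by similarity to the query, most similar first.
    ranked = sorted(corpus_vecs, key=lambda doc_id: dot(query_vec, corpus_vecs[doc_id]), reverse=True)
    # Keep only the requested rank window and drop known positives.
    start, end = sample_range
    window = [doc_id for doc_id in ranked[start:end] if doc_id not in positives]
    if len(window) <= negative_number:
        return window
    # Sample a fixed number of negatives from the window.
    return random.Random(seed).sample(window, negative_number)


# Toy corpus: d0 is the most similar document to the query, d5 the least.
corpus = {f"d{i}": [1.0 - 0.1 * i, 0.1 * i] for i in range(6)}
negatives = mine_hard_negatives(
    query_vec=[1.0, 0.0],
    positives=["d0"],
    corpus_vecs=corpus,
    sample_range=(1, 5),
    negative_number=2,
)
print(negatives)
```

Skipping the very top ranks is a common way to lower the risk of sampling unlabeled positives as negatives, while staying within the top few hundred keeps the negatives hard.
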
Matryoshka Representation Learning
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



Contrastive loss
6 changes: 6 additions & 0 deletions docs/source/index.rst
@@ -31,6 +31,12 @@ Now you are ready, proceed with
# install with support of evaluation
pip install open-retrievals[eval]
Or install from source code:

.. code-block:: shell

    python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git

Examples
------------------
2 changes: 1 addition & 1 deletion docs/source/rag.rst
@@ -39,7 +39,7 @@ Integrated with Langchain
rerank_model_name_or_path = "BAAI/bge-reranker-base"
llm_model_name_or_path = "microsoft/Phi-3-mini-128k-instruct"
embeddings = LangchainEmbedding(model_name_or_path=embed_model_name_or_path)
embeddings = LangchainEmbedding(model_name_or_path=embed_model_name_or_path, model_kwargs={'pooling_method': 'mean'})
vectordb = Vectorstore(
persist_directory=persist_directory,
embedding_function=embeddings,
2 changes: 2 additions & 0 deletions docs/source/retrieval.rst
@@ -6,6 +6,8 @@ Retrieval
1. Pipeline
----------------------------

Retrieval methods can solve **search** or **extreme multiclass classification** problems.

generate data -> train -> eval

pretrained encoding -> build hard negative -> train -> eval -> indexing -> retrieval
4 changes: 1 addition & 3 deletions examples/0_embedding/README.md
@@ -32,15 +32,13 @@
```



To train directly with a shell script, refer to the [documentation](https://open-retrievals.readthedocs.io/en/master/embed.html)

## Encoder embedding
## Transformer encoder embedding

Refer to [the fine-tuning code](./train_pairwise.py) to train the model like



## LLM embedding

Refer to [the fine-tuning code](./train_llm.py), to train the model like
3 changes: 0 additions & 3 deletions examples/0_embedding/train_llm.py
@@ -147,9 +147,6 @@ def __getitem__(self, item):
query = self.dataset[item]["query"] + self.tokenizer.eos_token
pos = self.dataset[item]["pos"][0] + self.tokenizer.eos_token
neg = self.dataset[item]["neg"][0] + self.tokenizer.eos_token
# pos = random.choice(self.dataset[item]["pos"])
# neg = random.choice(self.dataset[item]["neg"])

res = {"query": query, "pos": pos, "neg": neg}
return res

4 changes: 2 additions & 2 deletions examples/scifact/evaluate.py
@@ -4,7 +4,7 @@

from datasets import load_dataset

from retrievals.metrics import get_mrr, get_ndcg, get_recall
from retrievals.metrics import get_fbeta, get_mrr, get_ndcg


def transfer_index_to_id(save_path):
@@ -56,7 +56,7 @@ def transfer_index_to_id(save_path):
qid2ranking[qid].append(pid)

results = get_mrr(qid2positives, qid2ranking, cutoff_rank=10)
results.update(get_recall(qid2positives, qid2ranking, cutoff_ranks=[10]))
results.update(get_fbeta(qid2positives, qid2ranking, cutoff_ranks=[10]))
results.update(get_ndcg(qid2positives, qid2ranking, cutoff_rank=10))

print(json.dumps(results, indent=4))
3 changes: 2 additions & 1 deletion src/retrievals/metrics/__init__.py
@@ -1,4 +1,5 @@
from .fbeta import get_recall
from .fbeta import get_fbeta
from .hit_rate import get_hit_rate
from .map import get_map
from .mrr import get_mrr
from .ndcg import get_ndcg
4 changes: 3 additions & 1 deletion src/retrievals/metrics/fbeta.py
@@ -1,7 +1,9 @@
from typing import Dict, List


def get_recall(qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_ranks: List[int] = [10]):
def get_fbeta(
qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_ranks: List[int] = [10], beta: int = 2
):
qid2recall = {cutoff_rank: {} for cutoff_rank in cutoff_ranks}
num_samples = len(qid2ranking.keys())
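
The rest of `get_fbeta` is collapsed in this diff, so how `beta` is used there is not visible. For reference only, the standard F-beta score combines precision and recall as below; `fbeta_score` is a hypothetical helper name, not a function from this repository.

```python
def fbeta_score(precision: float, recall: float, beta: float = 2.0) -> float:
    """Standard F-beta: a weighted harmonic mean of precision and recall."""
    if precision == 0.0 and recall == 0.0:
        return 0.0
    beta_sq = beta * beta
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


# With beta=2 recall weighs more than precision; with beta=1 this is the usual F1.
print(fbeta_score(precision=0.5, recall=1.0, beta=2.0))
print(fbeta_score(precision=0.5, recall=1.0, beta=1.0))
```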

29 changes: 29 additions & 0 deletions src/retrievals/metrics/hit_rate.py
@@ -0,0 +1,29 @@
from typing import Dict, List


def get_hit_rate(qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_rank: int = 10):
    """
    qid2positive (order doesn't matter): {qid: [pos1_doc_id, pos2_doc_id]}
    qid2ranking (order does matter): {qid: [rank1_doc_id, rank2_doc_id, rank3_doc_id]}
    """

    def hit_rate(positives_ids: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
        """
        Calculate hit rate at the specified cutoff
        """
        hits = 0

        for doc_id in ranked_doc_ids[:cutoff]:
            if doc_id in positives_ids:
                hits += 1

        return hits / cutoff if cutoff > 0 else 0.0

    qid2hr = dict()

    for qid in qid2positive:
        positives_ids = qid2positive[qid]
        ranked_doc_ids = qid2ranking[qid]
        qid2hr[qid] = hit_rate(positives_ids, ranked_doc_ids, cutoff_rank)

    return {f"hit_rate@{cutoff_rank}": sum(qid2hr.values()) / len(qid2hr) if qid2hr else 0.0}
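
A small worked example may help show what this metric returns. The sketch below restates the same arithmetic in a standalone helper (`hit_rate_at_k` is a name chosen for this example, and the toy ids are made up): the number of positives found in the top `cutoff` results is divided by `cutoff`, then averaged over queries.

```python
from typing import Dict, List


def hit_rate_at_k(positive_ids: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
    """Positives found in the top `cutoff` results, divided by `cutoff`."""
    if cutoff <= 0:
        return 0.0
    hits = sum(1 for doc_id in ranked_doc_ids[:cutoff] if doc_id in positive_ids)
    return hits / cutoff


qid2positive: Dict[str, List[str]] = {"q1": ["d1", "d2"], "q2": ["d9"]}
qid2ranking: Dict[str, List[str]] = {"q1": ["d1", "d3", "d2", "d4"], "q2": ["d5", "d6", "d7", "d8"]}

# q1 has one positive in its top 2; q2 has none.
per_query = {qid: hit_rate_at_k(qid2positive[qid], qid2ranking[qid], cutoff=2) for qid in qid2positive}
mean_hit_rate = sum(per_query.values()) / len(per_query)
print(per_query)
print(mean_hit_rate)
```

Note that a query can score below 1.0 here even when a positive is retrieved, since the count is divided by the cutoff instead of being reduced to a 0/1 hit indicator.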
16 changes: 8 additions & 8 deletions src/retrievals/metrics/map.py
@@ -3,29 +3,29 @@

def get_map(qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_rank: int = 10):
"""
qid2positive: {qid: [pos1_doc_id, pos2_doc_id]}
qid2ranking: {qid: [rank1_doc_id, rank2_doc_id, rank3_doc_id]}
qid2positive (order doesn't matter): {qid: [pos1_doc_id, pos2_doc_id]}
qid2ranking (order does matter): {qid: [rank1_doc_id, rank2_doc_id, rank3_doc_id]}
"""

def average_precision(positives: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
def average_precision(positives_ids: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
"""
for each cut_off, calculate its precision
Average of precision for each cut_off
"""
hits = 0
sum_precisions = 0.0

for rank, doc_id in enumerate(ranked_doc_ids[:cutoff], start=1):
if doc_id in positives:
if doc_id in positives_ids:
hits += 1
sum_precisions += hits / rank

return sum_precisions / len(positives) if positives else 0.0
return sum_precisions / min(len(positives_ids), cutoff) if positives_ids else 0.0

qid2map = dict()

for qid in qid2positive:
positives = qid2positive[qid]
positives_ids = qid2positive[qid]
ranked_doc_ids = qid2ranking[qid]
qid2map[qid] = average_precision(positives, ranked_doc_ids, cutoff_rank)
qid2map[qid] = average_precision(positives_ids, ranked_doc_ids, cutoff_rank)

return {f"map@{cutoff_rank}": sum(qid2map.values()) / len(qid2ranking.keys())}
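
The change of denominator in `average_precision` matters when a query has more positives than the cutoff. The following sketch uses toy data and a standalone helper (`sum_of_precisions` is a name chosen for this example) to compare the old normalization, `len(positives)`, with the new one, `min(len(positives_ids), cutoff)`.

```python
from typing import List


def sum_of_precisions(positive_ids: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
    """Sum of precision@rank over the ranks (up to cutoff) where a positive appears."""
    hits = 0
    total = 0.0
    for rank, doc_id in enumerate(ranked_doc_ids[:cutoff], start=1):
        if doc_id in positive_ids:
            hits += 1
            total += hits / rank
    return total


positives = ["d1", "d2", "d3"]  # three relevant documents
ranking = ["d1", "d2", "d7", "d8"]  # the top 2 results are both relevant
cutoff = 2

total = sum_of_precisions(positives, ranking, cutoff)  # 1/1 + 2/2
old_ap = total / len(positives)  # old normalization
new_ap = total / min(len(positives), cutoff)  # normalization after this commit
print(old_ap, new_ap)
```

With the old denominator a perfect top-2 ranking could not reach 1.0 for this query, because only two of the three positives can fit within the cutoff.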