Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update rerank and retrieval #23

Merged
merged 6 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/retrievals/data/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,16 @@ def __init__(
if passage_max_length:
self.passage_max_length = passage_max_length

def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
assert (
'query' in features and 'passage' in features
), "RerankCollator should have 'query' and 'passage' keys in features dict, and 'labels' during training"

query_texts = [feature["query"] for feature in features]
passage_texts = [feature['passage'] for feature in features]
def __call__(self, features: Union[List[Dict[str, Any]], List]) -> BatchEncoding:
if isinstance(features[0], dict):
assert (
'query' in features[0] and 'passage' in features[0]
), "RerankCollator should have 'query' and 'passage' keys in features dict, and 'labels' during training"
query_texts = [feature["query"] for feature in features]
passage_texts = [feature['passage'] for feature in features]
else:
query_texts = [feature[0] for feature in features]
passage_texts = [feature[1] for feature in features]

labels = None
if 'labels' in features[0].keys():
Expand Down
20 changes: 14 additions & 6 deletions src/retrievals/models/embedding_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def batch_to_device(batch: Dict, target_device: str) -> Dict[str, torch.Tensor]:

class AutoModelForEmbedding(nn.Module):
"""
Loads or creates a Embedding model that can be used to map sentences / text.
Loads or creates an Embedding model that can be used to map sentences / text.

:param model_name_or_path: If it is a filepath on disc, it loads the model from that path. If it is not a path,
it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model
Expand All @@ -79,7 +79,7 @@ def __init__(
use_lora: bool = False,
lora_config=None,
device: Optional[str] = None,
trust_remote_code: bool = False,
trust_remote_code: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -201,7 +201,7 @@ def forward_from_text(self, texts):

def encode(
self,
inputs,
inputs: Union[DataLoader, Dict, List, str],
batch_size: int = 128,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
Expand Down Expand Up @@ -380,7 +380,7 @@ def encode_from_text(

return all_embeddings

def build_index(self, inputs: BatchEncoding, batch_size: int = 64, use_gpu: bool = True):
def build_index(self, inputs: BatchEncoding, batch_size: int = 128, use_gpu: bool = True):
embeddings = self.encode(inputs, batch_size=batch_size)
embeddings = np.asarray(embeddings, dtype=np.float32)
index = faiss.IndexFlatL2(len(embeddings[0]))
Expand All @@ -401,8 +401,16 @@ def search(self):
def similarity(self, queries: Union[str, List[str]], keys: Union[str, List[str], ndarray]):
return

def save(self):
pass
def save(self, path: str):
"""
Saves all model and tokenizer to path
"""
if path is None:
return

logger.info("Save model to {}".format(path))
self.model.save_pretrained(path)
self.tokenizer.save_pretrained(path)

@classmethod
def from_pretrained(cls, model_name_or_path: str, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions src/retrievals/models/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
logger = logging.getLogger(__name__)


class RAG(object):
class SimpleRAG(object):
def __init__(self):
pass

Expand Down Expand Up @@ -44,7 +44,7 @@ def search(self):
return


class ChatGenerator(object):
class Generator(object):
def __init__(self, config_path: str):
self.config_path = config_path

Expand Down
20 changes: 15 additions & 5 deletions src/retrievals/models/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import (
AutoConfig,
AutoModel,
AutoModelForSequenceClassification,
AutoTokenizer,
)

from src.retrievals.data.collator import RerankCollator
from src.retrievals.models.embedding_auto import get_device_name
Expand Down Expand Up @@ -33,7 +38,9 @@ def __init__(
model_name_or_path, return_tensors=False, trust_remote_code=trust_remote_code
)

self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code
)
if gradient_checkpointing:
self.model.graident_checkpointing_enable()
if device is None:
Expand All @@ -53,9 +60,9 @@ def __init__(
self.model.print_trainable_parameters()

self.pooling = AutoPooling(pooling_method)
num_features = self.backbone.config.hidden_size
num_features = self.model.config.hidden_size
self.classifier = nn.Linear(num_features, 1)
self._init_weights(self.classifier)
# self._init_weights(self.classifier)
self.loss_fn = loss_fn

if max_length is None:
Expand Down Expand Up @@ -130,6 +137,7 @@ def compute_score(
if isinstance(text_pair, str):
text_pair = [text_pair]
assert len(text) == len(text_pair), f"Length of text {len(text)} and text_pair {len(text_pair)} should be same"
batch_size = min(batch_size, len(text))

if not data_collator:
data_collator = RerankCollator(tokenizer=self.tokenizer)
Expand All @@ -139,7 +147,9 @@ def compute_score(
for i in range(0, len(text), batch_size):
text_batch = [{'query': text[i], 'passage': text_pair[i]} for i in range(i, i + batch_size)]
batch = data_collator(text_batch)
scores = self.model(batch, return_dict=True).logits.view(-1).float()
scores = (
self.model(batch['input_ids'], batch['attention_mask'], return_dict=True).logits.view(-1).float()
)
scores = torch.sigmoid(scores)
scores_list.extend(scores.cpu().numpy().tolist())

Expand Down
56 changes: 27 additions & 29 deletions src/retrievals/models/retrieval_auto.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from typing import Union
from typing import Optional, Union

import faiss
import numpy as np
import pandas as pd
import torch
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
Expand All @@ -19,47 +20,49 @@ def __init__(self, method: str = "cosine") -> None:
def similarity_search(
self,
query_embed: torch.Tensor,
passage_embed: torch.Tensor,
passage_embed: Optional[torch.Tensor] = None,
index_path: Optional[str] = None,
top_k: int = 1,
batch_size: int = -1,
convert_to_numpy: bool = True,
convert_to_pandas: bool = False,
**kwargs,
):
if self.method == "knn":
if passage_embed is None and index_path is None:
logging.warning('Please provide passage_embed for knn/tensor search or index_path for faiss search')
return
if index_path is not None:
faiss_index = faiss.read_index(index_path)
dists, indices = faiss_search(
query_embeddings=query_embed,
faiss_index=faiss_index,
top_k=top_k,
batch_size=batch_size,
)

elif self.method == "knn":
neighbors_model = NearestNeighbors(n_neighbors=top_k, metric="cosine", n_jobs=-1)
neighbors_model.fit(passage_embed)
dists, indices = neighbors_model.kneighbors(query_embed)
return dists, indices

elif self.method == "cosine":
dists, indices = cosine_similarity_search(
query_embed, passage_embed, top_k=top_k, batch_size=batch_size, convert_to_numpy=convert_to_numpy
)
return dists, indices

else:
raise ValueError(f"Only cosine and knn method are supported by similarity_search, while get {self.method}")

def faiss_search(
self,
query_embed: torch.Tensor,
index_path: str = "/faiss.index",
top_k: int = 1,
batch_size: int = 128,
max_length: int = 512,
):
faiss_index = faiss.read_index(index_path)
dists, indices = faiss_search(
query_embeddings=query_embed,
faiss_index=faiss_index,
top_k=top_k,
batch_size=batch_size,
)
return dists, indices
if not convert_to_pandas:
return dists, indices

def get_rerank_df(self):
rerank_data = dict({'query': [], 'passage': [], 'labels': []})
return rerank_data
retrieval = dict({'query': [], 'passage': [], 'labels': []})
return pd.from_dict(retrieval)


class EnsembleRetriever(object):
def __init__(self, retrievers, weights=None):
pass


def cosine_similarity_search(
Expand Down Expand Up @@ -136,11 +139,6 @@ def faiss_search(
return all_scores, all_indices


class EnsembleRetriever(object):
def __init__(self, retrievers, weights=None):
pass


class FaissIndex:
def __init__(self, device) -> None:
if isinstance(device, torch.device):
Expand Down
5 changes: 0 additions & 5 deletions src/retrievals/tools/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,3 @@ def compress_documents(

final_results = final_results[: self.top_n]
return final_results


class RagFeature(object):
def __init__(self, config_path: str = 'config.ini'):
pass
Loading
Loading