diff --git a/README.md b/README.md index 1b9a9425..d61bab65 100644 --- a/README.md +++ b/README.md @@ -240,14 +240,14 @@ model = AutoModelForEmbedding( from retrievals import AutoModelForEmbedding, AutoModelForRetrieval query_texts = ['A dog is chasing car.'] -passage_texts = ['A man is playing a guitar.', 'A bee is flying low'] +document_texts = ['A man is playing a guitar.', 'A bee is flying low'] model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2" model = AutoModelForEmbedding(model_name_or_path) query_embeddings = model.encode(query_texts, convert_to_tensor=True) -passage_embeddings = model.encode(passage_texts, convert_to_tensor=True) +document_embeddings = model.encode(document_texts, convert_to_tensor=True) matcher = AutoModelForRetrieval(method='cosine') -dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1) +dists, indices = matcher.similarity_search(query_embeddings, document_embeddings, top_k=1) ``` ## Reference & Acknowledge diff --git a/README_zh-CN.md b/README_zh-CN.md index 4bf2d2ec..2816b1b0 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -85,13 +85,13 @@ print(sentence_embeddings) from retrievals import AutoModelForEmbedding, AutoModelForRetrieval query_texts = [] -passage_texts = [] +document_texts = [] model = AutoModelForEmbedding('') query_embeddings = model.encode(query_texts, convert_to_tensor=True) -passage_embeddings = model.encode(passage_texts, convert_to_tensor=True) +document_embeddings = model.encode(document_texts, convert_to_tensor=True) matcher = AutoModelForRetrieval(method='cosine') -dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1) +dists, indices = matcher.similarity_search(query_embeddings, document_embeddings, top_k=1) ``` **Faiss向量数据库检索** diff --git a/examples/finetune_llm_embed.py b/examples/finetune_llm_embed.py index 5fc52895..54b26c8c 100644 --- a/examples/finetune_llm_embed.py +++ b/examples/finetune_llm_embed.py @@ -62,14 +62,14 @@ class DataArguments: query_max_len: int = field( default=32, metadata={ - "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " + "help": "The maximum total input sequence length after tokenization for document. Sequences longer " "than this will be truncated, sequences shorter will be padded." }, ) - passage_max_len: int = field( + document_max_len: int = field( default=32, metadata={ - "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " + "help": "The maximum total input sequence length after tokenization for document. Sequences longer " "than this will be truncated, sequences shorter will be padded." }, ) @@ -80,7 +80,7 @@ class DataArguments: query_instruction: str = field( default="Instruct: Retrieve semantically similar text.\nQuery: ", metadata={"help": "instruction for query"} ) - passage_instruction: str = field(default=None, metadata={"help": "instruction for passage"}) + document_instruction: str = field(default=None, metadata={"help": "instruction for document"}) def __post_init__(self): if not os.path.exists(self.train_data): diff --git a/examples/finetune_pairwise_embed.py b/examples/finetune_pairwise_embed.py index 1b0f18a2..eae03e24 100644 --- a/examples/finetune_pairwise_embed.py +++ b/examples/finetune_pairwise_embed.py @@ -59,15 +59,15 @@ class DataArguments: query_max_len: int = field( default=32, metadata={ - "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " + "help": "The maximum total input sequence length after tokenization for document. Sequences longer " "than this will be truncated, sequences shorter will be padded." }, ) - passage_max_len: int = field( + document_max_len: int = field( default=128, metadata={ - "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " + "help": "The maximum total input sequence length after tokenization for document. Sequences longer " "than this will be truncated, sequences shorter will be padded." }, ) @@ -78,7 +78,7 @@ class DataArguments: ) query_instruction: str = field(default=None, metadata={"help": "instruction for query"}) - passage_instruction: str = field(default=None, metadata={"help": "instruction for passage"}) + document_instruction: str = field(default=None, metadata={"help": "instruction for document"}) def __post_init__(self): if not os.path.exists(self.train_data): diff --git a/examples/rerank_cross_encoder.py b/examples/rerank_cross_encoder.py index 9727e29b..a191631b 100644 --- a/examples/rerank_cross_encoder.py +++ b/examples/rerank_cross_encoder.py @@ -56,14 +56,14 @@ class DataArguments: query_max_len: int = field( default=32, metadata={ - "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " + "help": "The maximum total input sequence length after tokenization for document. Sequences longer " "than this will be truncated, sequences shorter will be padded." }, ) - passage_max_len: int = field( + document_max_len: int = field( default=32, metadata={ - "help": "The maximum total input sequence length after tokenization for passage. Sequences longer " + "help": "The maximum total input sequence length after tokenization for document. Sequences longer " "than this will be truncated, sequences shorter will be padded." }, ) @@ -74,7 +74,7 @@ class DataArguments: query_instruction: str = field( default="Instruct: Retrieve semantically similar text.\nQuery: ", metadata={"help": "instruction for query"} ) - passage_instruction: str = field(default=None, metadata={"help": "instruction for passage"}) + document_instruction: str = field(default=None, metadata={"help": "instruction for document"}) def __post_init__(self): if not os.path.exists(self.train_data): diff --git a/src/retrievals/data/collator.py b/src/retrievals/data/collator.py index 3761c310..ee0003f9 100644 --- a/src/retrievals/data/collator.py +++ b/src/retrievals/data/collator.py @@ -10,25 +10,25 @@ def __init__( tokenizer, max_length: Optional[int] = None, query_max_length: Optional[int] = None, - passage_max_length: Optional[int] = None, + document_max_length: Optional[int] = None, ) -> None: self.tokenizer = tokenizer if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None: self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) self.query_max_length: int - self.passage_max_length: int + self.document_max_length: int if query_max_length: self.query_max_length = query_max_length elif max_length: self.query_max_length = max_length - self.passage_max_length = max_length + self.document_max_length = max_length else: self.query_max_length = tokenizer.model_max_length - self.passage_max_length = tokenizer.model_max_length + self.document_max_length = tokenizer.model_max_length - if passage_max_length: - self.passage_max_length = passage_max_length + if document_max_length: + self.document_max_length = document_max_length def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: assert ( @@ -48,7 +48,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: pos_inputs = self.tokenizer( pos_texts, padding=True, - max_length=self.passage_max_length, + max_length=self.document_max_length, truncation=True, return_tensors="pt", ) @@ -62,25 +62,25 @@ def __init__( tokenizer, max_length: Optional[int] = None, query_max_length: Optional[int] = None, - passage_max_length: Optional[int] = None, + document_max_length: Optional[int] = None, ) -> None: self.tokenizer = tokenizer if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None: self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) self.query_max_length: int - self.passage_max_length: int + self.document_max_length: int if query_max_length: self.query_max_length = query_max_length elif max_length: self.query_max_length = max_length - self.passage_max_length = max_length + self.document_max_length = max_length else: self.query_max_length = tokenizer.model_max_length - self.passage_max_length = tokenizer.model_max_length + self.document_max_length = tokenizer.model_max_length - if passage_max_length: - self.passage_max_length = passage_max_length + if document_max_length: + self.document_max_length = document_max_length def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: assert ( @@ -93,8 +93,8 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: # if isinstance(query[0], list): # query = sum(query, []) - # if isinstance(passage[0], list): - # passage = sum(passage, []) + # if isinstance(document[0], list): + # document = sum(document, []) query_inputs = self.tokenizer( query_texts, @@ -106,14 +106,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: pos_inputs = self.tokenizer( pos_texts, padding=True, - max_length=self.passage_max_length, + max_length=self.document_max_length, truncation=True, return_tensors="pt", ) # ["input_ids"] neg_inputs = self.tokenizer( neg_texts, padding=True, - max_length=self.passage_max_length, + max_length=self.document_max_length, truncation=True, return_tensors="pt", ) # ["input_ids"] @@ -131,43 +131,43 @@ def __init__( tokenizer, max_length: Optional[int] = None, query_max_length: Optional[int] = None, - passage_max_length: Optional[int] = None, + document_max_length: Optional[int] = None, ): self.tokenizer = tokenizer if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None: self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) self.query_max_length: int - self.passage_max_length: int + self.document_max_length: int if query_max_length: self.query_max_length = query_max_length elif max_length: self.query_max_length = max_length - self.passage_max_length = max_length + self.document_max_length = max_length else: self.query_max_length = tokenizer.model_max_length - self.passage_max_length = tokenizer.model_max_length + self.document_max_length = tokenizer.model_max_length - if passage_max_length: - self.passage_max_length = passage_max_length + if document_max_length: + self.document_max_length = document_max_length def __call__(self, features: Union[List[Dict[str, Any]], List]) -> BatchEncoding: if isinstance(features[0], dict): assert ( - 'query' in features[0] and 'passage' in features[0] - ), "RerankCollator should have 'query' and 'passage' keys in features dict, and 'labels' during training" + 'query' in features[0] and 'document' in features[0] + ), "RerankCollator should have 'query' and 'document' keys in features dict, and 'labels' during training" query_texts = [feature["query"] for feature in features] - passage_texts = [feature['passage'] for feature in features] + document_texts = [feature['document'] for feature in features] else: query_texts = [feature[0] for feature in features] - passage_texts = [feature[1] for feature in features] + document_texts = [feature[1] for feature in features] labels = None if 'labels' in features[0].keys(): labels = [feature['labels'] for feature in features] batch = self.tokenizer( - text=query_texts, text_pair=passage_texts, truncation=True, max_length=self.max_length, return_tensors="pt" + text=query_texts, text_pair=document_texts, truncation=True, max_length=self.max_length, return_tensors="pt" ) # for key in ['input_ids', 'attention_mask']: diff --git a/src/retrievals/data/dataset.py b/src/retrievals/data/dataset.py index fd8a2af6..7e01af25 100644 --- a/src/retrievals/data/dataset.py +++ b/src/retrievals/data/dataset.py @@ -74,7 +74,7 @@ def __len__(self): def __getitem__(self, item): query = self.dataset[item]["query"] - passage = self.dataset[item]['passage'] + document = self.dataset[item]['document'] labels = self.dataset[item]['labels'] - sample = {"query": query, "passage": passage, "neg": labels} + sample = {"query": query, "document": document, "neg": labels} return sample diff --git a/src/retrievals/losses/cosine_similarity.py b/src/retrievals/losses/cosine_similarity.py index ec729822..3aa38963 100644 --- a/src/retrievals/losses/cosine_similarity.py +++ b/src/retrievals/losses/cosine_similarity.py @@ -16,12 +16,12 @@ def __init__(self, temperature: float = 0.0, dynamic_temperature=False): self.temperature = temperature self.dynamic_temperature = dynamic_temperature - def forward(self, query_embeddings: torch.Tensor, passage_embeddings: torch.Tensor): - sim_pos_vector = torch.cosine_similarity(query_embeddings, passage_embeddings, dim=-1) + def forward(self, query_embeddings: torch.Tensor, document_embeddings: torch.Tensor): + sim_pos_vector = torch.cosine_similarity(query_embeddings, document_embeddings, dim=-1) sim_pos_vector = sim_pos_vector / self.temperature sim_neg_matrix = torch.cosine_similarity( query_embeddings.unsqueeze(1), - passage_embeddings.unsqueeze(0), + document_embeddings.unsqueeze(0), dim=-1, ) sim_neg_matrix = sim_neg_matrix / self.temperature diff --git a/src/retrievals/models/embedding_auto.py b/src/retrievals/models/embedding_auto.py index 58ed9b05..b292eb3a 100644 --- a/src/retrievals/models/embedding_auto.py +++ b/src/retrievals/models/embedding_auto.py @@ -73,7 +73,7 @@ def __init__( max_length: Optional[int] = None, loss_fn: Optional[Callable] = None, query_instruction: Optional[str] = None, - passage_instruction: Optional[str] = None, + document_instruction: Optional[str] = None, generation_args: Dict = None, use_fp16: bool = False, use_lora: bool = False, @@ -130,7 +130,7 @@ def __init__( # self._init_weights(self.fc) self.query_instruction = query_instruction - self.passage_instruction = passage_instruction + self.document_instruction = document_instruction if generation_args is not None: generation_config = self.model.generation_config.to_dict() generation_config.update(generation_args) diff --git a/src/retrievals/models/rerank.py b/src/retrievals/models/rerank.py index 50a7e5e1..ded15738 100644 --- a/src/retrievals/models/rerank.py +++ b/src/retrievals/models/rerank.py @@ -145,7 +145,7 @@ def compute_score( with torch.no_grad(): scores_list: List = [] for i in range(0, len(text), batch_size): - text_batch = [{'query': text[i], 'passage': text_pair[i]} for i in range(i, i + batch_size)] + text_batch = [{'query': text[i], 'document': text_pair[i]} for i in range(i, i + batch_size)] batch = data_collator(text_batch) scores = ( self.model(batch['input_ids'], batch['attention_mask'], return_dict=True).logits.view(-1).float() @@ -158,30 +158,30 @@ def compute_score( def rerank( self, query: Union[List[str], str], - passages: List[str], + document: List[str], data_collator: Optional[RerankCollator] = None, batch_size: int = 32, show_progress_bar: bool = None, return_dict: bool = True, **kwargs, ): - merge_scores = self.compute_score(query, passages, data_collator, batch_size, show_progress_bar) + merge_scores = self.compute_score(query, document, data_collator, batch_size, show_progress_bar) merge_scores_argsort = np.argsort(merge_scores)[::-1] - sorted_passages = [] + sorted_document = [] sorted_scores = [] for mid in merge_scores_argsort: sorted_scores.append(merge_scores[mid]) - sorted_passages.append(passages[mid]) + sorted_document.append(document[mid]) if return_dict: return { - 'rerank_passages': sorted_passages, + 'rerank_document': sorted_document, 'rerank_scores': sorted_scores, 'rerank_ids': merge_scores_argsort.tolist(), } else: - return sorted_passages + return sorted_document def save(self, path: str): """ diff --git a/src/retrievals/models/retrieval_auto.py b/src/retrievals/models/retrieval_auto.py index ae8d9ed4..75459c9d 100644 --- a/src/retrievals/models/retrieval_auto.py +++ b/src/retrievals/models/retrieval_auto.py @@ -20,7 +20,7 @@ def __init__(self, method: Literal['cosine', 'knn'] = "cosine") -> None: def similarity_search( self, query_embed: torch.Tensor, - passage_embed: Optional[torch.Tensor] = None, + document_embed: Optional[torch.Tensor] = None, index_path: Optional[str] = None, top_k: int = 1, batch_size: int = -1, @@ -28,8 +28,8 @@ def similarity_search( **kwargs, ): self.top_k = top_k - if passage_embed is None and index_path is None: - logging.warning('Please provide passage_embed for knn/tensor search or index_path for faiss search') + if document_embed is None and index_path is None: + logging.warning('Please provide document_embed for knn/tensor search or index_path for faiss search') return if index_path is not None: faiss_index = faiss.read_index(index_path) @@ -42,12 +42,12 @@ def similarity_search( elif self.method == "knn": neighbors_model = NearestNeighbors(n_neighbors=top_k, metric="cosine", n_jobs=-1) - neighbors_model.fit(passage_embed) + neighbors_model.fit(document_embed) dists, indices = neighbors_model.kneighbors(query_embed) elif self.method == "cosine": dists, indices = cosine_similarity_search( - query_embed, passage_embed, top_k=top_k, batch_size=batch_size, convert_to_numpy=convert_to_numpy + query_embed, document_embed, top_k=top_k, batch_size=batch_size, convert_to_numpy=convert_to_numpy ) else: @@ -55,15 +55,15 @@ def similarity_search( return dists, indices - def get_pandas_candidate(self, query_ids, passage_ids, dists, indices): + def get_pandas_candidate(self, query_ids, document_ids, dists, indices): if isinstance(query_ids, pd.Series): query_ids = query_ids.values - if isinstance(passage_ids, pd.Series): - passage_ids = passage_ids.values + if isinstance(document_ids, pd.Series): + document_ids = document_ids.values retrieval = { 'query': np.repeat(query_ids, self.top_k), - 'passage': passage_ids[indices.ravel()], + 'document': document_ids[indices.ravel()], 'scores': dists.ravel(), } return pd.DataFrame(retrieval) @@ -76,7 +76,7 @@ def __init__(self, retrievers, weights=None): def cosine_similarity_search( query_embed: torch.Tensor, - passage_embed: torch.Tensor, + document_embed: torch.Tensor, top_k: int = 1, batch_size: int = 128, penalty: bool = True, @@ -86,9 +86,9 @@ def cosine_similarity_search( ): if len(query_embed.size()) == 1: query_embed = query_embed.view(1, -1) - assert query_embed.size()[1] == passage_embed.size()[1], ( - f"The embed Shape of query_embed and passage_embed should be same, " - f"while received query {query_embed.size()} and passage {passage_embed.size()}" + assert query_embed.size()[1] == document_embed.size()[1], ( + f"The embed Shape of query_embed and document_embed should be same, " + f"while received query {query_embed.size()} and document {document_embed.size()}" ) chunk = batch_size if batch_size > 0 else len(query_embed) embeddings_chunks = query_embed.split(chunk) @@ -96,7 +96,7 @@ def cosine_similarity_search( dists = [] indices = [] for idx in trange(0, len(embeddings_chunks), desc="Batches", disable=not show_progress_bar): - cos_sim_chunk = torch.matmul(embeddings_chunks[idx], passage_embed.transpose(0, 1)) + cos_sim_chunk = torch.matmul(embeddings_chunks[idx], document_embed.transpose(0, 1)) cos_sim_chunk = torch.nan_to_num(cos_sim_chunk, nan=0.0) # if penalty: # pen = ((contents["old_source_count"].values==0) & (contents["old_nonsource_count"].values==1)) diff --git a/src/retrievals/trainer/custom_trainer.py b/src/retrievals/trainer/custom_trainer.py index 9069add6..c39c99a1 100644 --- a/src/retrievals/trainer/custom_trainer.py +++ b/src/retrievals/trainer/custom_trainer.py @@ -2,7 +2,7 @@ import logging import math import time -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -13,11 +13,11 @@ def train_fn( - epoch, - model, + epoch: int, + model: nn.Module, train_loader, optimizer, - criterion=None, + criterion: Optional[Callable] = None, apex: bool = False, gradient_accumulation_steps: int = 1, max_grad_norm: float = 1, @@ -223,15 +223,15 @@ def inference_fn(test_loader, model, device): class CustomTrainer(object): - def __init__(self, model: Union[str, nn.Module], device=None, apex=False, teacher=None): + def __init__(self, model: Union[str, nn.Module], device: Optional[str] = None, apex=False, teacher=None): if not device: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = device self.model = model self.teacher = teacher self.apex = apex - self.train_step = train_fn - self.valid_step = valid_fn + self.train_fn = train_fn + self.valid_fn = valid_fn def train( self, @@ -246,6 +246,7 @@ def train( max_grad_norm=10, **kwargs, ): + logger.info('-------START TO TRAIN-------') # best_score = 0 # fgm = FGM(model) # awp = None @@ -271,7 +272,7 @@ def train( # 如果把参数直接放在cuda, 则导致parameter里没有出现该参数,无法分别设置学习率 self.model = self.model.to(self.device) optimizer.zero_grad() - self.train_step( + self.train_fn( epoch=epoch, train_loader=train_loader, model=self.model, @@ -287,7 +288,7 @@ def train( ) if valid_loader is not None: - self.valid_step( + self.valid_fn( epoch=epoch, valid_loader=valid_loader, model=self.model, diff --git a/tests/test_losses/test_cosine_similarity.py b/tests/test_losses/test_cosine_similarity.py index e4d7d3c7..37bfd27a 100644 --- a/tests/test_losses/test_cosine_similarity.py +++ b/tests/test_losses/test_cosine_similarity.py @@ -9,7 +9,7 @@ class CosineSimilarityTest(TestCase): def setUp(self): # Setup can adjust parameters for wide coverage of scenarios self.query_embeddings = torch.randn(10, 128) # Example embeddings - self.passage_embeddings = torch.randn(10, 128) + self.document_embeddings = torch.randn(10, 128) self.temperature = 0.1 def test_loss_computation(self): @@ -17,7 +17,7 @@ def test_loss_computation(self): module = CosineSimilarity(temperature=self.temperature) # Compute loss - loss = module(self.query_embeddings, self.passage_embeddings) + loss = module(self.query_embeddings, self.document_embeddings) # Check if loss is a single scalar value and not nan or inf self.assertTrue(torch.isfinite(loss)) @@ -25,11 +25,11 @@ def test_loss_computation(self): def test_temperature_effect(self): # High temperature high_temp_module = CosineSimilarity(temperature=100.0) - high_temp_loss = high_temp_module(self.query_embeddings, self.passage_embeddings) + high_temp_loss = high_temp_module(self.query_embeddings, self.document_embeddings) # Low temperature low_temp_module = CosineSimilarity(temperature=0.01) - low_temp_loss = low_temp_module(self.query_embeddings, self.passage_embeddings) + low_temp_loss = low_temp_module(self.query_embeddings, self.document_embeddings) # Expect the loss to be higher for the lower temperature due to sharper softmax self.assertTrue(low_temp_loss > high_temp_loss)