fix: rename passage to document (#27)
LongxingTan authored Apr 18, 2024
1 parent 3eac842 commit 9b768c6
Showing 13 changed files with 88 additions and 87 deletions.

README.md (6 changes: 3 additions & 3 deletions)

@@ -240,14 +240,14 @@ model = AutoModelForEmbedding(
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

query_texts = ['A dog is chasing car.']
-passage_texts = ['A man is playing a guitar.', 'A bee is flying low']
+document_texts = ['A man is playing a guitar.', 'A bee is flying low']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
query_embeddings = model.encode(query_texts, convert_to_tensor=True)
-passage_embeddings = model.encode(passage_texts, convert_to_tensor=True)
+document_embeddings = model.encode(document_texts, convert_to_tensor=True)

matcher = AutoModelForRetrieval(method='cosine')
-dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1)
+dists, indices = matcher.similarity_search(query_embeddings, document_embeddings, top_k=1)
```

## Reference & Acknowledge

README_zh-CN.md (6 changes: 3 additions & 3 deletions)

@@ -85,13 +85,13 @@ print(sentence_embeddings)
from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

query_texts = []
-passage_texts = []
+document_texts = []
model = AutoModelForEmbedding('')
query_embeddings = model.encode(query_texts, convert_to_tensor=True)
-passage_embeddings = model.encode(passage_texts, convert_to_tensor=True)
+document_embeddings = model.encode(document_texts, convert_to_tensor=True)

matcher = AutoModelForRetrieval(method='cosine')
-dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1)
+dists, indices = matcher.similarity_search(query_embeddings, document_embeddings, top_k=1)
```

**Faiss vector database retrieval**

examples/finetune_llm_embed.py (8 changes: 4 additions & 4 deletions)

@@ -62,14 +62,14 @@ class DataArguments:
query_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"help": "The maximum total input sequence length after tokenization for document. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
-passage_max_len: int = field(
+document_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"help": "The maximum total input sequence length after tokenization for document. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
@@ -80,7 +80,7 @@ class DataArguments:
query_instruction: str = field(
default="Instruct: Retrieve semantically similar text.\nQuery: ", metadata={"help": "instruction for query"}
)
-passage_instruction: str = field(default=None, metadata={"help": "instruction for passage"})
+document_instruction: str = field(default=None, metadata={"help": "instruction for document"})

def __post_init__(self):
if not os.path.exists(self.train_data):
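
As an aside (not part of the diff): assuming these example scripts hand DataArguments to transformers.HfArgumentParser, as is typical, the renamed field surfaces on the command line as --document_max_len rather than --passage_max_len. A minimal sketch under that assumption:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


# Trimmed-down stand-in for the DataArguments class shown above.
@dataclass
class DataArguments:
    query_max_len: int = field(default=32)
    document_max_len: int = field(default=32)  # renamed from passage_max_len in this commit


parser = HfArgumentParser(DataArguments)
# Parse an explicit argument list instead of sys.argv so the sketch is self-contained.
(data_args,) = parser.parse_args_into_dataclasses(args=["--document_max_len", "64"])
print(data_args.document_max_len)  # 64
```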

examples/finetune_pairwise_embed.py (8 changes: 4 additions & 4 deletions)

@@ -59,15 +59,15 @@ class DataArguments:
query_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"help": "The maximum total input sequence length after tokenization for document. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)

-passage_max_len: int = field(
+document_max_len: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"help": "The maximum total input sequence length after tokenization for document. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
@@ -78,7 +78,7 @@ class DataArguments:
)

query_instruction: str = field(default=None, metadata={"help": "instruction for query"})
-passage_instruction: str = field(default=None, metadata={"help": "instruction for passage"})
+document_instruction: str = field(default=None, metadata={"help": "instruction for document"})

def __post_init__(self):
if not os.path.exists(self.train_data):

examples/rerank_cross_encoder.py (8 changes: 4 additions & 4 deletions)

@@ -56,14 +56,14 @@ class DataArguments:
query_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"help": "The maximum total input sequence length after tokenization for document. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
-passage_max_len: int = field(
+document_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"help": "The maximum total input sequence length after tokenization for document. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
@@ -74,7 +74,7 @@ class DataArguments:
query_instruction: str = field(
default="Instruct: Retrieve semantically similar text.\nQuery: ", metadata={"help": "instruction for query"}
)
-passage_instruction: str = field(default=None, metadata={"help": "instruction for passage"})
+document_instruction: str = field(default=None, metadata={"help": "instruction for document"})

def __post_init__(self):
if not os.path.exists(self.train_data):

src/retrievals/data/collator.py (56 changes: 28 additions & 28 deletions)

@@ -10,25 +10,25 @@ def __init__(
tokenizer,
max_length: Optional[int] = None,
query_max_length: Optional[int] = None,
-passage_max_length: Optional[int] = None,
+document_max_length: Optional[int] = None,
) -> None:
self.tokenizer = tokenizer
if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

self.query_max_length: int
-self.passage_max_length: int
+self.document_max_length: int
if query_max_length:
self.query_max_length = query_max_length
elif max_length:
self.query_max_length = max_length
-self.passage_max_length = max_length
+self.document_max_length = max_length
else:
self.query_max_length = tokenizer.model_max_length
-self.passage_max_length = tokenizer.model_max_length
+self.document_max_length = tokenizer.model_max_length

-if passage_max_length:
-self.passage_max_length = passage_max_length
+if document_max_length:
+self.document_max_length = document_max_length

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
assert (
@@ -48,7 +48,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
pos_inputs = self.tokenizer(
pos_texts,
padding=True,
-max_length=self.passage_max_length,
+max_length=self.document_max_length,
truncation=True,
return_tensors="pt",
)
@@ -62,25 +62,25 @@ def __init__(
tokenizer,
max_length: Optional[int] = None,
query_max_length: Optional[int] = None,
-passage_max_length: Optional[int] = None,
+document_max_length: Optional[int] = None,
) -> None:
self.tokenizer = tokenizer
if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

self.query_max_length: int
-self.passage_max_length: int
+self.document_max_length: int
if query_max_length:
self.query_max_length = query_max_length
elif max_length:
self.query_max_length = max_length
-self.passage_max_length = max_length
+self.document_max_length = max_length
else:
self.query_max_length = tokenizer.model_max_length
-self.passage_max_length = tokenizer.model_max_length
+self.document_max_length = tokenizer.model_max_length

-if passage_max_length:
-self.passage_max_length = passage_max_length
+if document_max_length:
+self.document_max_length = document_max_length

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
assert (
@@ -93,8 +93,8 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:

# if isinstance(query[0], list):
# query = sum(query, [])
-# if isinstance(passage[0], list):
-# passage = sum(passage, [])
+# if isinstance(document[0], list):
+# document = sum(document, [])

query_inputs = self.tokenizer(
query_texts,
@@ -106,14 +106,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
pos_inputs = self.tokenizer(
pos_texts,
padding=True,
-max_length=self.passage_max_length,
+max_length=self.document_max_length,
truncation=True,
return_tensors="pt",
) # ["input_ids"]
neg_inputs = self.tokenizer(
neg_texts,
padding=True,
-max_length=self.passage_max_length,
+max_length=self.document_max_length,
truncation=True,
return_tensors="pt",
) # ["input_ids"]
@@ -131,43 +131,43 @@ def __init__(
tokenizer,
max_length: Optional[int] = None,
query_max_length: Optional[int] = None,
-passage_max_length: Optional[int] = None,
+document_max_length: Optional[int] = None,
):
self.tokenizer = tokenizer
if not hasattr(self.tokenizer, "pad_token_id") or self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

self.query_max_length: int
-self.passage_max_length: int
+self.document_max_length: int
if query_max_length:
self.query_max_length = query_max_length
elif max_length:
self.query_max_length = max_length
-self.passage_max_length = max_length
+self.document_max_length = max_length
else:
self.query_max_length = tokenizer.model_max_length
-self.passage_max_length = tokenizer.model_max_length
+self.document_max_length = tokenizer.model_max_length

-if passage_max_length:
-self.passage_max_length = passage_max_length
+if document_max_length:
+self.document_max_length = document_max_length

def __call__(self, features: Union[List[Dict[str, Any]], List]) -> BatchEncoding:
if isinstance(features[0], dict):
assert (
-'query' in features[0] and 'passage' in features[0]
-), "RerankCollator should have 'query' and 'passage' keys in features dict, and 'labels' during training"
+'query' in features[0] and 'document' in features[0]
+), "RerankCollator should have 'query' and 'document' keys in features dict, and 'labels' during training"
query_texts = [feature["query"] for feature in features]
-passage_texts = [feature['passage'] for feature in features]
+document_texts = [feature['document'] for feature in features]
else:
query_texts = [feature[0] for feature in features]
-passage_texts = [feature[1] for feature in features]
+document_texts = [feature[1] for feature in features]

labels = None
if 'labels' in features[0].keys():
labels = [feature['labels'] for feature in features]

batch = self.tokenizer(
-text=query_texts, text_pair=passage_texts, truncation=True, max_length=self.max_length, return_tensors="pt"
+text=query_texts, text_pair=document_texts, truncation=True, max_length=self.max_length, return_tensors="pt"
)

# for key in ['input_ids', 'attention_mask']:
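
For context, a minimal standalone sketch of the query/document pairing that RerankCollator builds after the rename. The checkpoint name is an arbitrary example, and padding=True is added here so the snippet runs on its own:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint only

# After this commit, rerank features carry a 'document' key instead of 'passage'.
features = [
    {"query": "A dog is chasing car.", "document": "A man is playing a guitar.", "labels": 0.0},
    {"query": "A dog is chasing car.", "document": "A bee is flying low", "labels": 1.0},
]
query_texts = [f["query"] for f in features]
document_texts = [f["document"] for f in features]

# The same query/document pairing the collator performs internally.
batch = tokenizer(
    text=query_texts,
    text_pair=document_texts,
    padding=True,
    truncation=True,
    max_length=128,
    return_tensors="pt",
)
print(batch["input_ids"].shape)
```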

src/retrievals/data/dataset.py (4 changes: 2 additions & 2 deletions)

@@ -74,7 +74,7 @@ def __len__(self):

def __getitem__(self, item):
query = self.dataset[item]["query"]
-passage = self.dataset[item]['passage']
+document = self.dataset[item]['document']
labels = self.dataset[item]['labels']
sample = {"query": query, "passage": passage, "neg": labels}
sample = {"query": query, "document": document, "neg": labels}
return sample

src/retrievals/losses/cosine_similarity.py (6 changes: 3 additions & 3 deletions)

@@ -16,12 +16,12 @@ def __init__(self, temperature: float = 0.0, dynamic_temperature=False):
self.temperature = temperature
self.dynamic_temperature = dynamic_temperature

-def forward(self, query_embeddings: torch.Tensor, passage_embeddings: torch.Tensor):
-sim_pos_vector = torch.cosine_similarity(query_embeddings, passage_embeddings, dim=-1)
+def forward(self, query_embeddings: torch.Tensor, document_embeddings: torch.Tensor):
+sim_pos_vector = torch.cosine_similarity(query_embeddings, document_embeddings, dim=-1)
sim_pos_vector = sim_pos_vector / self.temperature
sim_neg_matrix = torch.cosine_similarity(
query_embeddings.unsqueeze(1),
-passage_embeddings.unsqueeze(0),
+document_embeddings.unsqueeze(0),
dim=-1,
)
sim_neg_matrix = sim_neg_matrix / self.temperature
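
A standalone sketch (not from the diff) of the tensor shapes the renamed forward arguments imply, assuming a batch of 4 queries and 8-dimensional embeddings:

```python
import torch

query_embeddings = torch.randn(4, 8)     # (batch, dim)
document_embeddings = torch.randn(4, 8)  # (batch, dim)

# Positive pairs: element-wise cosine similarity, shape (batch,)
sim_pos_vector = torch.cosine_similarity(query_embeddings, document_embeddings, dim=-1)

# Every query against every document: broadcast to shape (batch, batch)
sim_neg_matrix = torch.cosine_similarity(
    query_embeddings.unsqueeze(1),
    document_embeddings.unsqueeze(0),
    dim=-1,
)
print(sim_pos_vector.shape, sim_neg_matrix.shape)  # torch.Size([4]) torch.Size([4, 4])
```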

src/retrievals/models/embedding_auto.py (4 changes: 2 additions & 2 deletions)

@@ -73,7 +73,7 @@ def __init__(
max_length: Optional[int] = None,
loss_fn: Optional[Callable] = None,
query_instruction: Optional[str] = None,
-passage_instruction: Optional[str] = None,
+document_instruction: Optional[str] = None,
generation_args: Dict = None,
use_fp16: bool = False,
use_lora: bool = False,
@@ -130,7 +130,7 @@ def __init__(
# self._init_weights(self.fc)

self.query_instruction = query_instruction
-self.passage_instruction = passage_instruction
+self.document_instruction = document_instruction
if generation_args is not None:
generation_config = self.model.generation_config.to_dict()
generation_config.update(generation_args)
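
A hypothetical construction using the renamed keyword, based only on the parameters visible above and the README snippet; the model name and instruction string are examples:

```python
from retrievals import AutoModelForEmbedding

# document_instruction replaces the former passage_instruction keyword.
model = AutoModelForEmbedding(
    "sentence-transformers/all-MiniLM-L6-v2",
    query_instruction="Instruct: Retrieve semantically similar text.\nQuery: ",
    document_instruction=None,
)
```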

src/retrievals/models/rerank.py (14 changes: 7 additions & 7 deletions)

@@ -145,7 +145,7 @@ def compute_score(
with torch.no_grad():
scores_list: List = []
for i in range(0, len(text), batch_size):
-text_batch = [{'query': text[i], 'passage': text_pair[i]} for i in range(i, i + batch_size)]
+text_batch = [{'query': text[i], 'document': text_pair[i]} for i in range(i, i + batch_size)]
batch = data_collator(text_batch)
scores = (
self.model(batch['input_ids'], batch['attention_mask'], return_dict=True).logits.view(-1).float()
@@ -158,30 +158,30 @@
def rerank(
self,
query: Union[List[str], str],
-passages: List[str],
+document: List[str],
data_collator: Optional[RerankCollator] = None,
batch_size: int = 32,
show_progress_bar: bool = None,
return_dict: bool = True,
**kwargs,
):
-merge_scores = self.compute_score(query, passages, data_collator, batch_size, show_progress_bar)
+merge_scores = self.compute_score(query, document, data_collator, batch_size, show_progress_bar)

merge_scores_argsort = np.argsort(merge_scores)[::-1]
-sorted_passages = []
+sorted_document = []
sorted_scores = []
for mid in merge_scores_argsort:
sorted_scores.append(merge_scores[mid])
-sorted_passages.append(passages[mid])
+sorted_document.append(document[mid])

if return_dict:
return {
-'rerank_passages': sorted_passages,
+'rerank_document': sorted_document,
'rerank_scores': sorted_scores,
'rerank_ids': merge_scores_argsort.tolist(),
}
else:
-return sorted_passages
+return sorted_document

def save(self, path: str):
"""
(Diffs for the remaining 3 changed files are not shown.)
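
Finally, a small standalone sketch of the reordering that rerank performs and of the renamed keys in its return value; the scores here are made up:

```python
import numpy as np

document = ["A man is playing a guitar.", "A bee is flying low"]
merge_scores = np.array([0.12, 0.87])  # made-up relevance scores for the sketch

order = np.argsort(merge_scores)[::-1]  # highest score first
result = {
    "rerank_document": [document[i] for i in order],  # renamed from 'rerank_passages'
    "rerank_scores": merge_scores[order].tolist(),
    "rerank_ids": order.tolist(),
}
print(result)
```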
