Skip to content

Commit

Permalink
fix: retrieval pandas output (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
LongxingTan authored Apr 16, 2024
1 parent d5788f2 commit 3eac842
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/retrievals/models/embedding_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def encode_from_loader(
embeddings = embeddings.cpu()
all_embeddings.append(embeddings)
if convert_to_numpy:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
all_embeddings = np.concatenate([emb.numpy() for emb in all_embeddings], axis=0)
else:
all_embeddings = torch.concat(all_embeddings)
return all_embeddings
Expand Down
23 changes: 16 additions & 7 deletions src/retrievals/models/retrieval_auto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Union
from typing import Literal, Optional, Union

import faiss
import numpy as np
Expand All @@ -13,7 +13,7 @@


class AutoModelForRetrieval(object):
def __init__(self, method: str = "cosine") -> None:
def __init__(self, method: Literal['cosine', 'knn'] = "cosine") -> None:
super().__init__()
self.method = method

Expand All @@ -25,9 +25,9 @@ def similarity_search(
top_k: int = 1,
batch_size: int = -1,
convert_to_numpy: bool = True,
convert_to_pandas: bool = False,
**kwargs,
):
self.top_k = top_k
if passage_embed is None and index_path is None:
logging.warning('Please provide passage_embed for knn/tensor search or index_path for faiss search')
return
Expand All @@ -53,11 +53,20 @@ def similarity_search(
else:
raise ValueError(f"Only cosine and knn method are supported by similarity_search, while get {self.method}")

if not convert_to_pandas:
return dists, indices
return dists, indices

retrieval = dict({'query': [], 'passage': [], 'labels': []})
return pd.from_dict(retrieval)
def get_pandas_candidate(self, query_ids, passage_ids, dists, indices):
    """Flatten search results into a long-format candidate table.

    Each query id is repeated once per retrieved candidate and paired with
    the passage id found at the corresponding position in ``indices`` and
    the matching entry of ``dists``.

    Args:
        query_ids: One id per query; a ``pd.Series``, array, or sequence.
        passage_ids: Passage ids, looked up positionally through
            ``indices``; a ``pd.Series``, array, list, or tuple.
        dists: Scores, flattened with ``ravel()``. Expected to be laid out
            as (n_queries, top_k) so rows line up with the repeated query
            ids -- TODO confirm against ``similarity_search``.
        indices: Positions into ``passage_ids``, same layout as ``dists``.

    Returns:
        A ``pd.DataFrame`` with columns ``query``, ``passage`` and ``scores``.
    """
    if isinstance(query_ids, pd.Series):
        query_ids = query_ids.values
    if isinstance(passage_ids, pd.Series):
        passage_ids = passage_ids.values
    elif isinstance(passage_ids, (list, tuple)):
        # Plain sequences do not support indexing by an integer array.
        passage_ids = np.asarray(passage_ids)

    # ``self.top_k`` is only assigned by ``similarity_search``; when this
    # method is called on its own, infer the per-query candidate count from
    # the shape of ``indices`` instead of raising AttributeError.
    top_k = getattr(self, 'top_k', None)
    if top_k is None:
        top_k = indices.shape[-1] if indices.ndim > 1 else 1

    retrieval = {
        'query': np.repeat(query_ids, top_k),
        'passage': passage_ids[indices.ravel()],
        'scores': dists.ravel(),
    }
    return pd.DataFrame(retrieval)


class EnsembleRetriever(object):
Expand Down

0 comments on commit 3eac842

Please sign in to comment.