diff --git a/src/retrievals/models/embedding_auto.py b/src/retrievals/models/embedding_auto.py index 3daf5928..58ed9b05 100644 --- a/src/retrievals/models/embedding_auto.py +++ b/src/retrievals/models/embedding_auto.py @@ -263,7 +263,7 @@ def encode_from_loader( embeddings = embeddings.cpu() all_embeddings.append(embeddings) if convert_to_numpy: - all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + all_embeddings = np.concatenate([emb.numpy() for emb in all_embeddings], axis=0) else: all_embeddings = torch.concat(all_embeddings) return all_embeddings diff --git a/src/retrievals/models/retrieval_auto.py b/src/retrievals/models/retrieval_auto.py index d258bf75..ae8d9ed4 100644 --- a/src/retrievals/models/retrieval_auto.py +++ b/src/retrievals/models/retrieval_auto.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Union +from typing import Literal, Optional, Union import faiss import numpy as np @@ -13,7 +13,7 @@ class AutoModelForRetrieval(object): - def __init__(self, method: str = "cosine") -> None: + def __init__(self, method: Literal['cosine', 'knn'] = "cosine") -> None: super().__init__() self.method = method @@ -25,9 +25,9 @@ def similarity_search( top_k: int = 1, batch_size: int = -1, convert_to_numpy: bool = True, - convert_to_pandas: bool = False, **kwargs, ): + self.top_k = top_k if passage_embed is None and index_path is None: logging.warning('Please provide passage_embed for knn/tensor search or index_path for faiss search') return @@ -53,11 +53,20 @@ def similarity_search( else: raise ValueError(f"Only cosine and knn method are supported by similarity_search, while get {self.method}") - if not convert_to_pandas: - return dists, indices + return dists, indices - retrieval = dict({'query': [], 'passage': [], 'labels': []}) - return pd.from_dict(retrieval) + def get_pandas_candidate(self, query_ids, passage_ids, dists, indices): + if isinstance(query_ids, pd.Series): + query_ids = query_ids.values + if isinstance(passage_ids, pd.Series): + passage_ids = passage_ids.values + + retrieval = { + 'query': np.repeat(query_ids, self.top_k), + 'passage': passage_ids[indices.ravel()], + 'scores': dists.ravel(), + } + return pd.DataFrame(retrieval) class EnsembleRetriever(object):