From 02aaf6d06bb1b48e25449290583df15dea77b93b Mon Sep 17 00:00:00 2001 From: LongxingTan Date: Mon, 11 Mar 2024 21:35:56 +0800 Subject: [PATCH] fix: update auto_match --- README.md | 4 ++-- src/retrievals/models/embedding_auto.py | 5 +++-- src/retrievals/models/match_auto.py | 12 ++++++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 70be549d..8d38520b 100644 --- a/README.md +++ b/README.md @@ -64,8 +64,8 @@ from retrievals import AutoModelForEmbedding sentences = ["Hello world", "How are you?"] model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2" -model = AutoModelForEmbedding(model_name_or_path, pooling_method="cls") -sentence_embeddings = model.encode(sentences) +model = AutoModelForEmbedding(model_name_or_path, pooling_method="mean", normalize_embeddings=True) +sentence_embeddings = model.encode(sentences, convert_to_tensor=True) print(sentence_embeddings) ``` diff --git a/src/retrievals/models/embedding_auto.py b/src/retrievals/models/embedding_auto.py index 45ebe553..78ba1f5c 100644 --- a/src/retrievals/models/embedding_auto.py +++ b/src/retrievals/models/embedding_auto.py @@ -5,7 +5,6 @@ import numpy as np import torch from numpy import ndarray -from peft import LoraConfig, TaskType, get_peft_model from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset from tqdm.autonotebook import trange @@ -76,7 +75,7 @@ def __init__( generation_args: Dict = None, use_fp16: bool = False, use_lora: bool = False, - peft_config: Optional[LoraConfig] = None, + peft_config=None, device: Optional[str] = None, trust_remote_code: bool = False, ): @@ -110,6 +109,8 @@ def __init__( self.model.half() if use_lora: # peft config and wrapping + from peft import LoraConfig, TaskType, get_peft_model + if not peft_config: raise ValueError("If use_lora is true, please provide a valid peft_config") self.model = get_peft_model(self.model, peft_config) diff --git a/src/retrievals/models/match_auto.py b/src/retrievals/models/match_auto.py index 4598e8bc..474934f5 100644 --- a/src/retrievals/models/match_auto.py +++ b/src/retrievals/models/match_auto.py @@ -16,7 +16,13 @@ def __init__(self, method="cosine") -> None: self.method = method def similarity_search( - self, query_embed: torch.Tensor, passage_embed: torch.Tensor, top_k: int = 1, batch_size: int = 0, **kwargs + self, + query_embed: torch.Tensor, + passage_embed: torch.Tensor, + top_k: int = 1, + batch_size: int = 0, + convert_to_numpy: bool = True, + **kwargs, ): if self.method == "knn": neighbors_model = NearestNeighbors(n_neighbors=top_k, metric="cosine", n_jobs=-1) @@ -25,7 +31,9 @@ def similarity_search( return dists, indices elif self.method == "cosine": - dists, indices = cosine_similarity_search(query_embed, passage_embed, top_k=top_k, batch_size=batch_size) + dists, indices = cosine_similarity_search( + query_embed, passage_embed, top_k=top_k, batch_size=batch_size, convert_to_numpy=convert_to_numpy + ) return dists, indices else: