diff --git a/pyserini/encode/__init__.py b/pyserini/encode/__init__.py
index c9e28ee5e..034da0d88 100644
--- a/pyserini/encode/__init__.py
+++ b/pyserini/encode/__init__.py
@@ -14,11 +14,20 @@
 # limitations under the License.
 #
 
-from ._base import DocumentEncoder, QueryEncoder, JsonlCollectionIterator,\
-    RepresentationWriter, FaissRepresentationWriter, JsonlRepresentationWriter, PcaEncoder
+from ._base import (
+    DocumentEncoder,
+    MLXDocumentEncoder,
+    QueryEncoder,
+    MLXQueryEncoder,
+    JsonlCollectionIterator,
+    RepresentationWriter,
+    FaissRepresentationWriter,
+    JsonlRepresentationWriter,
+    PcaEncoder
+)
 from ._ance import AnceEncoder, AnceDocumentEncoder, AnceQueryEncoder
 from ._auto import AutoQueryEncoder, AutoDocumentEncoder
-from ._dpr import DprDocumentEncoder, DprQueryEncoder
+from ._dpr import MLXDprDocumentEncoder, DprDocumentEncoder, DprQueryEncoder, MLXDprQueryEncoder
 from ._tct_colbert import TctColBertDocumentEncoder, TctColBertQueryEncoder
 from ._aggretriever import AggretrieverDocumentEncoder, AggretrieverQueryEncoder
 from ._unicoil import UniCoilEncoder, UniCoilDocumentEncoder, UniCoilQueryEncoder
@@ -27,5 +36,12 @@
 from ._splade import SpladeQueryEncoder
 from ._slim import SlimQueryEncoder
 from ._openai import OpenAIDocumentEncoder, OpenAIQueryEncoder, OPENAI_API_RETRY_DELAY
-from ._cosdpr import CosDprEncoder, CosDprDocumentEncoder, CosDprQueryEncoder
+from ._cosdpr import (
+    CosDprEncoder,
+    MLXCosDprEncoder,
+    CosDprDocumentEncoder,
+    MLXCosDprDocumentEncoder,
+    CosDprQueryEncoder,
+    MLXCosDprQueryEncoder
+)
 from ._clip import ClipEncoder, ClipDocumentEncoder
\ No newline at end of file
diff --git a/pyserini/encode/_base.py b/pyserini/encode/_base.py
index 15e4cdb65..ed9890dba 100644
--- a/pyserini/encode/_base.py
+++ b/pyserini/encode/_base.py
@@ -21,6 +21,8 @@
 import numpy as np
 from tqdm import tqdm
+# apple silicon
+import mlx.core as mx
 
 
 class DocumentEncoder:
     def encode(self, texts, **kwargs):
@@ -33,12 +35,26 @@ def _mean_pooling(last_hidden_state, attention_mask):
         sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
         sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
         return sum_embeddings / sum_mask
+
+class MLXDocumentEncoder:
+    def encode(self, texts, **kwargs):
+        pass
+
+    @staticmethod
+    def _mean_pooling(last_hidden_state: mx.array, attention_mask: mx.array):
+        token_embeddings = last_hidden_state
+        input_mask_expanded = mx.broadcast_to(mx.expand_dims(attention_mask, -1), token_embeddings.shape).astype(mx.float32)
+        sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1)
+        sum_mask = mx.clip(input_mask_expanded.sum(axis=1), 1e-9, None)
+        return sum_embeddings / sum_mask
 
 
 class QueryEncoder:
     def encode(self, text, **kwargs):
         pass
 
+class MLXQueryEncoder:
+    def encode(self, text, **kwargs):
+        pass
 
 class PcaEncoder:
     def __init__(self, encoder, pca_model_path):
diff --git a/pyserini/encode/_cosdpr.py b/pyserini/encode/_cosdpr.py
index de2b6cf1f..42ae43731 100644
--- a/pyserini/encode/_cosdpr.py
+++ b/pyserini/encode/_cosdpr.py
@@ -17,9 +17,18 @@
 from typing import Optional
 
 import torch
+import numpy as np
 from transformers import PreTrainedModel, BertConfig, BertModel, BertTokenizer
 
-from pyserini.encode import DocumentEncoder, QueryEncoder
+import mlx.core as mx  # apple silicon
+import mlx.nn as mlx_nn
+
+from pyserini.encode import (
+    DocumentEncoder,
+    MLXDocumentEncoder,
+    QueryEncoder,
+    MLXQueryEncoder
+)
 
 
 class CosDprEncoder(PreTrainedModel):
@@ -72,6 +81,51 @@ def forward(
         return pooled_output
 
+
+class MLXCosDprEncoder(PreTrainedModel):
+    config_class = BertConfig
+    base_model_prefix = 'bert'
+    load_tf_weights = None
+
+    def __init__(self, config: BertConfig):
+        super().__init__(config)
+        self.config = config
+        self.bert = BertModel(config)
+        self.linear = mlx_nn.Linear(config.hidden_size, config.hidden_size)
+        self.init_weights()
+
+    def _init_weights(self, module):
+        """ Initialize the weights of the mlx projection layer """
+        if isinstance(module, mlx_nn.Linear):
+            module.weight = mx.random.normal(module.weight.shape) * self.config.initializer_range
+        if isinstance(module, mlx_nn.Linear) and 'bias' in module:
+            module.bias = mx.zeros(module.bias.shape)
+
+    def init_weights(self):
+        self.bert.init_weights()
+        self._init_weights(self.linear)
+
+    def forward(
+        self,
+        input_ids: mx.array,
+        attention_mask: Optional[mx.array] = None,
+    ):
+        input_shape = input_ids.shape
+        if attention_mask is None:
+            attention_mask = (
+                mx.ones(input_shape)
+                if input_ids is None
+                else (input_ids != self.bert.config.pad_token_id)
+            )
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+        sequence_output = outputs.last_hidden_state
+        pooled_output = sequence_output[:, 0, :]
+        # project with the mlx head, then L2-normalize the pooled representation
+        pooled_output = self.linear(mx.array(pooled_output.detach().cpu().numpy()))
+        pooled_output = pooled_output / mx.linalg.norm(pooled_output, ord=2, axis=1, keepdims=True)
+        return pooled_output
+
 
 
 class CosDprDocumentEncoder(DocumentEncoder):
     def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
         self.device = device
@@ -114,3 +168,41 @@ def encode(self, query: str, **kwargs):
         inputs.to(self.device)
         embeddings = self.model(inputs["input_ids"]).detach().cpu().numpy()
         return embeddings.flatten()
+
+
+class MLXCosDprDocumentEncoder(MLXDocumentEncoder):
+    def __init__(self, model_name, tokenizer_name=None):
+        self.model = MLXCosDprEncoder.from_pretrained(model_name)
+        # no explicit device placement: mlx arrays live in unified memory
+        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)
+
+    def encode(self, texts, titles=None, max_length=256, **kwargs):
+        if titles is not None:
+            texts = [f'{title} {text}' for title, text in zip(titles, texts)]
+        inputs = self.tokenizer(
+            texts,
+            max_length=max_length,
+            padding='longest',
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors='pt'
+        )
+        return np.array(self.model(inputs["input_ids"]), copy=False)
+
+
+class MLXCosDprQueryEncoder(MLXQueryEncoder):
+    def __init__(self, encoder_dir: str, tokenizer_name: str = None, **kwargs):
+        self.model = MLXCosDprEncoder.from_pretrained(encoder_dir)
+        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or encoder_dir)
+
+    def encode(self, query: str, **kwargs):
+        inputs = self.tokenizer(
+            query,
+            add_special_tokens=True,
+            return_tensors='pt',
+            truncation='only_first',
+            padding='longest',
+            return_token_type_ids=False,
+        )
+        embeddings = np.array(self.model(inputs["input_ids"]))
+        return embeddings.flatten()
diff --git a/pyserini/encode/_dpr.py b/pyserini/encode/_dpr.py
index 9e19a387c..4f8021671 100644
--- a/pyserini/encode/_dpr.py
+++ b/pyserini/encode/_dpr.py
@@ -14,9 +14,19 @@
 # limitations under the License.
 #
 
-from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer
+from transformers import (
+    DPRContextEncoder,
+    DPRContextEncoderTokenizer,
+    DPRQuestionEncoder,
+    DPRQuestionEncoderTokenizer
+)
 
-from pyserini.encode import DocumentEncoder, QueryEncoder
+from pyserini.encode import (
+    DocumentEncoder,
+    MLXDocumentEncoder,
+    QueryEncoder,
+    MLXQueryEncoder
+)
 
 
 class DprDocumentEncoder(DocumentEncoder):
@@ -62,3 +72,48 @@ def encode(self, query: str, **kwargs):
         input_ids.to(self.device)
         embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu().numpy()
         return embeddings.flatten()
+
+
+class MLXDprDocumentEncoder(MLXDocumentEncoder):
+    def __init__(self, model_name, tokenizer_name=None, device='cpu'):
+        self.device = device
+        self.model = DPRContextEncoder.from_pretrained(model_name)
+        self.model.to(self.device)
+        self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(tokenizer_name or model_name)
+
+    def encode(self, texts, titles=None, max_length=256, **kwargs):
+        if titles:
+            inputs = self.tokenizer(
+                titles,
+                text_pair=texts,
+                max_length=max_length,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+        else:
+            inputs = self.tokenizer(
+                texts,
+                max_length=max_length,
+                padding='longest',
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors='pt'
+            )
+        inputs.to(self.device)
+        return self.model(inputs["input_ids"]).pooler_output.detach().cpu().numpy()
+
+
+class MLXDprQueryEncoder(MLXQueryEncoder):
+    def __init__(self, model_name: str, tokenizer_name: str = None, device: str = 'cpu'):
+        self.device = device
+        self.model = DPRQuestionEncoder.from_pretrained(model_name)
+        self.model.to(self.device)
+        self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or model_name)
+
+    def encode(self, query: str, **kwargs):
+        input_ids = self.tokenizer(query, return_tensors='pt')
+        input_ids.to(self.device)
+        embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu().numpy()
+        return embeddings.flatten()
diff --git a/requirements.txt b/requirements.txt
index 7738fd42c..df1eacc09 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,3 +17,4 @@ tiktoken>=0.4.0
 pyarrow>=15.0.0
 pillow>=10.2.0
 pybind11>=2.11.0
+mlx>=0.14.1; sys_platform == "darwin"
diff --git a/tests/test_dpr_mlx.py b/tests/test_dpr_mlx.py
new file mode 100644
index 000000000..e69de29bb
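
tests/test_dpr_mlx.py is added above as an empty stub. A minimal sketch of the kind of self-contained unit test it could eventually hold, exercising the new MLXDocumentEncoder._mean_pooling helper without downloading any model weights; the test class name, toy tensors, and tolerance are illustrative assumptions, not part of the patch:

# Sketch only: assumes an Apple-silicon machine with the mlx package installed
# and pyserini importable; none of this is part of the diff above.
import unittest

import mlx.core as mx
import numpy as np

from pyserini.encode import MLXDocumentEncoder


class TestMlxMeanPooling(unittest.TestCase):
    def test_mean_pooling_ignores_padding(self):
        # batch of 1, sequence length 3, hidden size 2; the last position is padding
        last_hidden_state = mx.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
        attention_mask = mx.array([[1, 1, 0]])
        pooled = MLXDocumentEncoder._mean_pooling(last_hidden_state, attention_mask)
        # the mean is taken over the two unmasked positions only
        np.testing.assert_allclose(np.array(pooled), [[2.0, 3.0]], rtol=1e-5)


if __name__ == '__main__':
    unittest.main()

On an Apple-silicon machine with mlx installed, something like `python -m unittest tests.test_dpr_mlx` would cover the mlx pooling path without any network access.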