From 05834259b31d432fca9f7f71cf1290dc8df35a83 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Sat, 9 Nov 2024 00:18:30 +0100 Subject: [PATCH 01/27] colpali WIP. Works for texts, does not work for images --- fastembed/image/onnx_embedding.py | 12 ++++++++++++ fastembed/image/transform/operators.py | 16 +++++----------- fastembed/text/onnx_embedding.py | 12 ++++++++++++ fastembed/text/onnx_text_model.py | 4 ++++ 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py index 10b38280..13907436 100644 --- a/fastembed/image/onnx_embedding.py +++ b/fastembed/image/onnx_embedding.py @@ -53,6 +53,18 @@ }, "model_file": "model.onnx", }, + { + "model": "akshayballal/colpali-v1.2-merged", + "dim": 128, + "description": "", + "license": "mit", + "size_in_GB": 6.08, + "sources": { + "hf": "akshayballal/colpali-v1.2-merged-onnx", + }, + "additional_files": ["model.onnx_data"], + "model_file": "model.onnx", + }, ] diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 854b0917..22486949 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -50,9 +50,7 @@ def __init__( self.resample = resample def __call__(self, images: List[Image.Image]) -> List[Image.Image]: - return [ - resize(image, size=self.size, resample=self.resample) for image in images - ] + return [resize(image, size=self.size, resample=self.resample) for image in images] class Rescale(Transform): @@ -64,9 +62,7 @@ def __call__(self, images: List[np.ndarray]) -> List[np.ndarray]: class PILtoNDarray(Transform): - def __call__( - self, images: List[Union[Image.Image, np.ndarray]] - ) -> List[np.ndarray]: + def __call__(self, images: List[Union[Image.Image, np.ndarray]]) -> List[np.ndarray]: return [pil2ndarray(image) for image in images] @@ -120,7 +116,7 @@ def _get_convert_to_rgb(transforms: List[Transform], config: Dict[str, Any]): @staticmethod def _get_resize(transforms: List[Transform], config: Dict[str, Any]): mode = config.get("image_processor_type", "CLIPImageProcessor") - if mode == "CLIPImageProcessor": + if mode == "CLIPImageProcessor" or mode == "SiglipImageProcessor": if config.get("do_resize", False): size = config["size"] if "shortest_edge" in size: @@ -165,7 +161,7 @@ def _get_resize(transforms: List[Transform], config: Dict[str, Any]): @staticmethod def _get_center_crop(transforms: List[Transform], config: Dict[str, Any]): mode = config.get("image_processor_type", "CLIPImageProcessor") - if mode == "CLIPImageProcessor": + if mode == "CLIPImageProcessor" or mode == "SiglipImageProcessor": if config.get("do_center_crop", False): crop_size = config["crop_size"] if isinstance(crop_size, int): @@ -193,6 +189,4 @@ def _get_rescale(transforms: List[Transform], config: Dict[str, Any]): @staticmethod def _get_normalize(transforms: List[Transform], config: Dict[str, Any]): if config.get("do_normalize", False): - transforms.append( - Normalize(mean=config["image_mean"], std=config["image_std"]) - ) + transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"])) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 74ed5355..ec02b829 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -164,6 +164,18 @@ }, "model_file": "onnx/model.onnx", }, + { + "model": "akshayballal/colpali-v1.2-merged", + "dim": 128, + "description": "", + "license": "mit", + "size_in_GB": 6.08, + "sources": { + "hf": "akshayballal/colpali-v1.2-merged-onnx", + }, + "additional_files": ["model.onnx_data"], + "model_file": "model.onnx", + }, ] diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 36bfcb08..35b430dd 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -79,6 +79,10 @@ def onnx_embed( onnx_input["token_type_ids"] = np.array( [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 ) + if "pixel_values" in input_names: + onnx_input["pixel_values"] = np.zeros( + (np.array(input_ids, dtype=np.int64).shape[0], 3, 448, 448), dtype=np.float32 + ) onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) From f63f77c560aaa69d807592aab3a5697871eef42e Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Tue, 19 Nov 2024 18:01:39 +0100 Subject: [PATCH 02/27] WIP Test preprocessing for image-only --- fastembed/image/onnx_image_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fastembed/image/onnx_image_model.py b/fastembed/image/onnx_image_model.py index 86dac0e9..81d4f194 100644 --- a/fastembed/image/onnx_image_model.py +++ b/fastembed/image/onnx_image_model.py @@ -70,6 +70,12 @@ def onnx_embed(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: encoded = self.processor(image_files) onnx_input = self._build_onnx_input(encoded) onnx_input = self._preprocess_onnx_input(onnx_input) + onnx_input["input_ids"] = [ + np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) + for _ in onnx_input["input_ids"] + ] + onnx_input["attention_mask"] = [np.array([1] * 1030) for _ in onnx_input["input_ids"]] + model_output = self.model.run(None, onnx_input) embeddings = model_output[0].reshape(len(images), -1) return OnnxOutputContext(model_output=embeddings) From c4bd2c0e73d193dd2c8b87d4d2389d19876dbf76 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Wed, 20 Nov 2024 00:03:20 +0100 Subject: [PATCH 03/27] =?UTF-8?q?Image=20part=20ColPali=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastembed/image/image_embedding.py | 3 ++- fastembed/image/onnx_embedding.py | 12 ------------ fastembed/image/onnx_image_model.py | 5 ----- 3 files changed, 2 insertions(+), 18 deletions(-) diff --git a/fastembed/image/image_embedding.py b/fastembed/image/image_embedding.py index b404cb93..64bba155 100644 --- a/fastembed/image/image_embedding.py +++ b/fastembed/image/image_embedding.py @@ -5,10 +5,11 @@ from fastembed.common import ImageInput, OnnxProvider from fastembed.image.image_embedding_base import ImageEmbeddingBase from fastembed.image.onnx_embedding import OnnxImageEmbedding +from fastembed.image.colpali_model import ColpaliImageModel class ImageEmbedding(ImageEmbeddingBase): - EMBEDDINGS_REGISTRY: List[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding] + EMBEDDINGS_REGISTRY: List[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding, ColpaliImageModel] @classmethod def list_supported_models(cls) -> List[Dict[str, Any]]: diff --git a/fastembed/image/onnx_embedding.py b/fastembed/image/onnx_embedding.py index 13907436..10b38280 100644 --- a/fastembed/image/onnx_embedding.py +++ b/fastembed/image/onnx_embedding.py @@ -53,18 +53,6 @@ }, "model_file": "model.onnx", }, - { - "model": "akshayballal/colpali-v1.2-merged", - "dim": 128, - "description": "", - "license": "mit", - "size_in_GB": 6.08, - "sources": { - "hf": "akshayballal/colpali-v1.2-merged-onnx", - }, - "additional_files": ["model.onnx_data"], - "model_file": "model.onnx", - }, ] diff --git a/fastembed/image/onnx_image_model.py b/fastembed/image/onnx_image_model.py index 81d4f194..af5c8c45 100644 --- a/fastembed/image/onnx_image_model.py +++ b/fastembed/image/onnx_image_model.py @@ -70,11 +70,6 @@ def onnx_embed(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: encoded = self.processor(image_files) onnx_input = self._build_onnx_input(encoded) onnx_input = self._preprocess_onnx_input(onnx_input) - onnx_input["input_ids"] = [ - np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) - for _ in onnx_input["input_ids"] - ] - onnx_input["attention_mask"] = [np.array([1] * 1030) for _ in onnx_input["input_ids"]] model_output = self.model.run(None, onnx_input) embeddings = model_output[0].reshape(len(images), -1) From ee62c69935d01bf6e5028622e736e3f5b68bc608 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Wed, 20 Nov 2024 01:38:50 +0100 Subject: [PATCH 04/27] WIP Tokenizer part ColPali --- fastembed/common/preprocessor_utils.py | 1 + fastembed/image/colpali_model.py | 66 ++++++++++++++++++++ fastembed/text/colpali_model.py | 85 ++++++++++++++++++++++++++ fastembed/text/onnx_embedding.py | 12 ---- fastembed/text/onnx_text_model.py | 12 ---- 5 files changed, 152 insertions(+), 24 deletions(-) create mode 100644 fastembed/image/colpali_model.py create mode 100644 fastembed/text/colpali_model.py diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index db2432f3..0cc0d528 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -7,6 +7,7 @@ def load_special_tokens(model_dir: Path) -> dict: + print(model_dir) tokens_map_path = model_dir / "special_tokens_map.json" if not tokens_map_path.exists(): raise ValueError(f"Could not find special_tokens_map.json in {model_dir}") diff --git a/fastembed/image/colpali_model.py b/fastembed/image/colpali_model.py new file mode 100644 index 00000000..7db452c7 --- /dev/null +++ b/fastembed/image/colpali_model.py @@ -0,0 +1,66 @@ +import contextlib +from typing import Any, Dict, Iterable, List + +import numpy as np +from PIL import Image + +from fastembed.common import ImageInput +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.image.onnx_embedding import OnnxImageEmbedding + +supported_onnx_models = [ + { + "model": "akshayballal/colpali-v1.2-merged", + "dim": 128, + "description": "", + "license": "mit", + "size_in_GB": 6.08, + "sources": { + "hf": "akshayballal/colpali-v1.2-merged-onnx", + }, + "additional_files": ["model.onnx_data"], + "model_file": "model.onnx", + } +] + + +class ColpaliImageModel(OnnxImageEmbedding): + def _preprocess_onnx_input( + self, onnx_input: Dict[str, np.ndarray], **kwargs + ) -> Dict[str, np.ndarray]: + empty_text_placeholder = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) + even_attention_mask = np.array([1] * 1030) + onnx_input["input_ids"] = np.array( + [empty_text_placeholder for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array( + [even_attention_mask for _ in onnx_input["input_ids"]] + ) + return onnx_input + + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + """ + Lists the supported models. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_onnx_models + + def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]: + return output.model_output.astype(np.float32) + + def onnx_embed(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: + with contextlib.ExitStack(): + image_files = [ + Image.open(image) if not isinstance(image, Image.Image) else image + for image in images + ] + encoded = self.processor(image_files) + onnx_input = self._build_onnx_input(encoded) + onnx_input = self._preprocess_onnx_input(onnx_input) + + model_output = self.model.run(None, onnx_input) + embeddings = model_output[0].reshape(len(images), -1, supported_onnx_models[0]["dim"]) + return OnnxOutputContext(model_output=embeddings) diff --git a/fastembed/text/colpali_model.py b/fastembed/text/colpali_model.py new file mode 100644 index 00000000..84befb18 --- /dev/null +++ b/fastembed/text/colpali_model.py @@ -0,0 +1,85 @@ +from typing import Any, Dict, Iterable, List + +import numpy as np + +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.text.onnx_embedding import OnnxTextEmbedding + +supported_onnx_models = [ + { + "model": "akshayballal/colpali-v1.2-merged", + "dim": 128, + "description": "", + "license": "mit", + "size_in_GB": 6.08, + "sources": { + "hf": "akshayballal/colpali-v1.2-merged-onnx", + }, + "additional_files": [ + "model.onnx_data", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "config.json", + "preprocessor_config.json", + ], + "model_file": "model.onnx", + } +] + + +class ColpaliTextModel(OnnxTextEmbedding): + query_prefix = "Query: " + bos_token = "" + + def _preprocess_onnx_input( + self, onnx_input: Dict[str, np.ndarray], **kwargs + ) -> Dict[str, np.ndarray]: + empty_image_placeholder = np.zeros((3, 448, 448), dtype=np.float32) + onnx_input["pixel_values"] = np.array( + [empty_image_placeholder for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]]) + return onnx_input + + @classmethod + def list_supported_models(cls) -> List[Dict[str, Any]]: + """ + Lists the supported models. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_onnx_models + + def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]: + return output.model_output.astype(np.float32) + + def _preprocess_queries(self, documents: List[str]): + texts_query: List[str] = [] + + for query in documents: + query = self.bos_token + self.query_prefix + query + query += "\n" + + texts_query.append(query) + return texts_query + + def onnx_embed( + self, + documents: List[str], + **kwargs, + ) -> OnnxOutputContext: + documents = self._preprocess_queries(documents) + encoded = self.tokenize(documents, **kwargs) + input_ids = np.array([e.ids for e in encoded]) + attention_mask = np.array([e.attention_mask for e in encoded]) + onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} + onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) + print(onnx_input) + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) + return OnnxOutputContext( + model_output=model_output[0], + attention_mask=onnx_input.get("attention_mask", attention_mask), + input_ids=onnx_input.get("input_ids", input_ids), + ) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index ec02b829..74ed5355 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -164,18 +164,6 @@ }, "model_file": "onnx/model.onnx", }, - { - "model": "akshayballal/colpali-v1.2-merged", - "dim": 128, - "description": "", - "license": "mit", - "size_in_GB": 6.08, - "sources": { - "hf": "akshayballal/colpali-v1.2-merged-onnx", - }, - "additional_files": ["model.onnx_data"], - "model_file": "model.onnx", - }, ] diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 35b430dd..04720638 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -69,21 +69,9 @@ def onnx_embed( encoded = self.tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) - input_names = {node.name for node in self.model.get_inputs()} onnx_input = { "input_ids": np.array(input_ids, dtype=np.int64), } - if "attention_mask" in input_names: - onnx_input["attention_mask"] = np.array(attention_mask, dtype=np.int64) - if "token_type_ids" in input_names: - onnx_input["token_type_ids"] = np.array( - [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 - ) - if "pixel_values" in input_names: - onnx_input["pixel_values"] = np.zeros( - (np.array(input_ids, dtype=np.int64).shape[0], 3, 448, 448), dtype=np.float32 - ) - onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) From 34f557c4d2a80b894838ac5903a9d0841d5bca62 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Wed, 20 Nov 2024 15:44:04 +0100 Subject: [PATCH 05/27] Done: Tokenizer part ColPali --- fastembed/common/preprocessor_utils.py | 1 - fastembed/image/image_embedding.py | 2 +- fastembed/text/colpali_model.py | 9 ++++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index 60ff7ba0..228afcad 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -7,7 +7,6 @@ def load_special_tokens(model_dir: Path) -> dict: - print(model_dir) tokens_map_path = model_dir / "special_tokens_map.json" if not tokens_map_path.exists(): raise ValueError(f"Could not find special_tokens_map.json in {model_dir}") diff --git a/fastembed/image/image_embedding.py b/fastembed/image/image_embedding.py index 67039fc6..f481f486 100644 --- a/fastembed/image/image_embedding.py +++ b/fastembed/image/image_embedding.py @@ -9,7 +9,7 @@ class ImageEmbedding(ImageEmbeddingBase): - EMBEDDINGS_REGISTRY: List[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding, ColpaliImageModel] + EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding, ColpaliImageModel] @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: diff --git a/fastembed/text/colpali_model.py b/fastembed/text/colpali_model.py index 84befb18..977a1735 100644 --- a/fastembed/text/colpali_model.py +++ b/fastembed/text/colpali_model.py @@ -31,6 +31,7 @@ class ColpaliTextModel(OnnxTextEmbedding): query_prefix = "Query: " bos_token = "" + pad_token = "" def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], **kwargs @@ -59,7 +60,7 @@ def _preprocess_queries(self, documents: List[str]): texts_query: List[str] = [] for query in documents: - query = self.bos_token + self.query_prefix + query + query = self.bos_token + self.query_prefix + query + self.pad_token * 10 query += "\n" texts_query.append(query) @@ -71,12 +72,14 @@ def onnx_embed( **kwargs, ) -> OnnxOutputContext: documents = self._preprocess_queries(documents) + self.tokenizer.enable_truncation(max_length=10000) encoded = self.tokenize(documents, **kwargs) - input_ids = np.array([e.ids for e in encoded]) + input_ids = np.array([[2, 9413] + e.ids[2:] for e in encoded]) + attention_mask = np.array([e.attention_mask for e in encoded]) onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) - print(onnx_input) + onnx_input["attention_mask"] = attention_mask model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) return OnnxOutputContext( model_output=model_output[0], From 9274c7d556863851ad1939128b42e98ad801530f Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Wed, 20 Nov 2024 16:30:15 +0100 Subject: [PATCH 06/27] Done: Tests --- fastembed/image/colpali_model.py | 4 ++-- fastembed/text/colpali_model.py | 2 +- tests/test_image_onnx_embeddings.py | 17 +++++++++++++---- tests/test_text_onnx_embeddings.py | 24 ++++++++++++++++++++---- 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/fastembed/image/colpali_model.py b/fastembed/image/colpali_model.py index 7db452c7..ed461a50 100644 --- a/fastembed/image/colpali_model.py +++ b/fastembed/image/colpali_model.py @@ -11,7 +11,7 @@ supported_onnx_models = [ { "model": "akshayballal/colpali-v1.2-merged", - "dim": 128, + "dim": (1030, 128), "description": "", "license": "mit", "size_in_GB": 6.08, @@ -62,5 +62,5 @@ def onnx_embed(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: onnx_input = self._preprocess_onnx_input(onnx_input) model_output = self.model.run(None, onnx_input) - embeddings = model_output[0].reshape(len(images), -1, supported_onnx_models[0]["dim"]) + embeddings = model_output[0].reshape(len(images), *supported_onnx_models[0]["dim"]) return OnnxOutputContext(model_output=embeddings) diff --git a/fastembed/text/colpali_model.py b/fastembed/text/colpali_model.py index 977a1735..dc5e086f 100644 --- a/fastembed/text/colpali_model.py +++ b/fastembed/text/colpali_model.py @@ -8,7 +8,7 @@ supported_onnx_models = [ { "model": "akshayballal/colpali-v1.2-merged", - "dim": 128, + "dim": (16, 128), "description": "", "license": "mit", "size_in_GB": 6.08, diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 78194caf..90ca0654 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -21,6 +21,9 @@ "Qdrant/Unicom-ViT-B-32": np.array( [0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186] ), + "akshayballal/colpali-v1.2-merged": np.array( + [0.01533, 0.05118, 0.05948, 0.02583, -0.06128, -0.02682] + ), } @@ -43,13 +46,19 @@ def test_embedding(): ] embeddings = list(model.embed(images)) embeddings = np.stack(embeddings, axis=0) - assert embeddings.shape == (len(images), dim) canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] - assert np.allclose( - embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] + if isinstance(dim, tuple): + assert embeddings.shape == (len(images), *dim) + assert np.allclose( + embeddings[0][0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), model_desc["model"] + else: + assert embeddings.shape == (len(images), dim) + assert np.allclose( + embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), model_desc["model"] assert np.allclose(embeddings[1], embeddings[2]), model_desc["model"] diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index f576330c..28ac7553 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -64,6 +64,15 @@ ), "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]), "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]), + "akshayballal/colpali-v1.2-merged": [ + 0.1581, + -0.03748, + 0.09265, + -0.0002161, + 0.0762, + 0.02055, + 0.09937, + ], } @@ -80,12 +89,19 @@ def test_embedding(): docs = ["hello world", "flag embedding"] embeddings = list(model.embed(docs)) embeddings = np.stack(embeddings, axis=0) - assert embeddings.shape == (2, dim) canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] - assert np.allclose( - embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] + + if isinstance(dim, tuple): + assert embeddings.shape == (len(docs), *dim) + assert np.allclose( + embeddings[0][0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), model_desc["model"] + else: + assert embeddings.shape == (len(docs), dim) + assert np.allclose( + embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), model_desc["model"] if is_ci: delete_model_cache(model.model._model_dir) From d4f4e5ad1c5455e1a167c80d98950581245541f2 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Wed, 20 Nov 2024 16:43:23 +0100 Subject: [PATCH 07/27] Remove unnecessary changes --- fastembed/image/onnx_image_model.py | 1 - fastembed/image/transform/operators.py | 1 - fastembed/text/onnx_text_model.py | 7 +++++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/fastembed/image/onnx_image_model.py b/fastembed/image/onnx_image_model.py index 3c541c1b..895f006f 100644 --- a/fastembed/image/onnx_image_model.py +++ b/fastembed/image/onnx_image_model.py @@ -70,7 +70,6 @@ def onnx_embed(self, images: list[ImageInput], **kwargs) -> OnnxOutputContext: encoded = self.processor(image_files) onnx_input = self._build_onnx_input(encoded) onnx_input = self._preprocess_onnx_input(onnx_input) - model_output = self.model.run(None, onnx_input) embeddings = model_output[0].reshape(len(images), -1) return OnnxOutputContext(model_output=embeddings) diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 69f17cfa..494fa0d0 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -49,7 +49,6 @@ def __init__( self.size = size self.resample = resample - def __call__(self, images: list[Image.Image]) -> list[Image.Image]: return [resize(image, size=self.size, resample=self.resample) for image in images] diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 176788ea..ba3e1516 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -69,9 +69,16 @@ def onnx_embed( encoded = self.tokenize(documents, **kwargs) input_ids = np.array([e.ids for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) + input_names = {node.name for node in self.model.get_inputs()} onnx_input = { "input_ids": np.array(input_ids, dtype=np.int64), } + if "attention_mask" in input_names: + onnx_input["attention_mask"] = np.array(attention_mask, dtype=np.int64) + if "token_type_ids" in input_names: + onnx_input["token_type_ids"] = np.array( + [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 + ) onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) From 7ca807ec9e908ab0d0bfc1d41fb67168fed8f213 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Thu, 21 Nov 2024 11:50:51 +0100 Subject: [PATCH 08/27] Description changes --- fastembed/image/colpali_model.py | 2 +- fastembed/text/colpali_model.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/fastembed/image/colpali_model.py b/fastembed/image/colpali_model.py index ed461a50..0934c1d3 100644 --- a/fastembed/image/colpali_model.py +++ b/fastembed/image/colpali_model.py @@ -12,7 +12,7 @@ { "model": "akshayballal/colpali-v1.2-merged", "dim": (1030, 128), - "description": "", + "description": "Image embeddings, Unimodal (image), Aligned to text latent space via PaliGemma-3B, 512 patches max, 2024.", "license": "mit", "size_in_GB": 6.08, "sources": { diff --git a/fastembed/text/colpali_model.py b/fastembed/text/colpali_model.py index dc5e086f..de7c8ced 100644 --- a/fastembed/text/colpali_model.py +++ b/fastembed/text/colpali_model.py @@ -9,7 +9,7 @@ { "model": "akshayballal/colpali-v1.2-merged", "dim": (16, 128), - "description": "", + "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.", "license": "mit", "size_in_GB": 6.08, "sources": { @@ -19,9 +19,7 @@ "model.onnx_data", "tokenizer.json", "tokenizer_config.json", - "special_tokens_map.json", "config.json", - "preprocessor_config.json", ], "model_file": "model.onnx", } From 317ccecaa2f2a0e6715901fcf30ab73ebd6d5409 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Thu, 21 Nov 2024 23:34:29 +0100 Subject: [PATCH 09/27] Refactoring of magic numbers and values --- fastembed/image/colpali_model.py | 9 +++++---- fastembed/text/colpali_model.py | 6 ++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/fastembed/image/colpali_model.py b/fastembed/image/colpali_model.py index 0934c1d3..fa0209a2 100644 --- a/fastembed/image/colpali_model.py +++ b/fastembed/image/colpali_model.py @@ -25,16 +25,17 @@ class ColpaliImageModel(OnnxImageEmbedding): + empty_text_placeholder = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) + even_attention_mask = np.array([1] * 1030) + def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], **kwargs ) -> Dict[str, np.ndarray]: - empty_text_placeholder = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) - even_attention_mask = np.array([1] * 1030) onnx_input["input_ids"] = np.array( - [empty_text_placeholder for _ in onnx_input["input_ids"]] + [self.empty_text_placeholder for _ in onnx_input["input_ids"]] ) onnx_input["attention_mask"] = np.array( - [even_attention_mask for _ in onnx_input["input_ids"]] + [self.even_attention_mask for _ in onnx_input["input_ids"]] ) return onnx_input diff --git a/fastembed/text/colpali_model.py b/fastembed/text/colpali_model.py index de7c8ced..0320922c 100644 --- a/fastembed/text/colpali_model.py +++ b/fastembed/text/colpali_model.py @@ -30,11 +30,13 @@ class ColpaliTextModel(OnnxTextEmbedding): query_prefix = "Query: " bos_token = "" pad_token = "" + query_tokens = [2, 9413] + image_placeholder_size = (3, 448, 448) def _preprocess_onnx_input( self, onnx_input: Dict[str, np.ndarray], **kwargs ) -> Dict[str, np.ndarray]: - empty_image_placeholder = np.zeros((3, 448, 448), dtype=np.float32) + empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32) onnx_input["pixel_values"] = np.array( [empty_image_placeholder for _ in onnx_input["input_ids"]] ) @@ -72,7 +74,7 @@ def onnx_embed( documents = self._preprocess_queries(documents) self.tokenizer.enable_truncation(max_length=10000) encoded = self.tokenize(documents, **kwargs) - input_ids = np.array([[2, 9413] + e.ids[2:] for e in encoded]) + input_ids = np.array([self.query_tokens + e.ids[2:] for e in encoded]) attention_mask = np.array([e.attention_mask for e in encoded]) onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} From d581de9090acaa2ba2fcb60a9151d934ce937686 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Wed, 27 Nov 2024 10:51:07 +0100 Subject: [PATCH 10/27] Refactoring to late interaction class --- fastembed/late_interaction/colbert.py | 2 +- fastembed/late_interaction/colpali.py | 193 ++++++++++++++++++ .../late_interaction_image_embedding.py | 113 ++++++++++ .../late_interaction_image_embedding_base.py | 62 ++++++ .../late_interaction_text_embedding.py | 2 +- 5 files changed, 370 insertions(+), 2 deletions(-) create mode 100644 fastembed/late_interaction/colpali.py create mode 100644 fastembed/late_interaction/late_interaction_image_embedding.py create mode 100644 fastembed/late_interaction/late_interaction_image_embedding_base.py diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 4d65fc29..27de10ad 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -7,7 +7,7 @@ from fastembed.common import OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir -from fastembed.late_interaction.late_interaction_embedding_base import ( +from fastembed.late_interaction.late_interaction_text_embedding_base import ( LateInteractionTextEmbeddingBase, ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker diff --git a/fastembed/late_interaction/colpali.py b/fastembed/late_interaction/colpali.py new file mode 100644 index 00000000..ddf4d700 --- /dev/null +++ b/fastembed/late_interaction/colpali.py @@ -0,0 +1,193 @@ +from typing import Any, Iterable, Optional, Sequence, Union + +import numpy as np + +from fastembed.common import OnnxProvider +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.image.onnx_image_model import OnnxImageModel +from fastembed.late_interaction.late_interaction_image_embedding_base import ( + LateInteractionImageEmbeddingBase, +) +from PIL import Image +from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker +import contextlib +from fastembed.common import ImageInput +from fastembed.common.preprocessor_utils import load_preprocessor + + +supported_colpali_models = [ + { + "model": "akshayballal/colpali-v1.2-merged", + "dim": (16, 128), + "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.", + "license": "mit", + "size_in_GB": 6.08, + "sources": { + "hf": "akshayballal/colpali-v1.2-merged-onnx", + }, + "additional_files": [ + "model.onnx_data", + "tokenizer.json", + "tokenizer_config.json", + "config.json", + ], + "model_file": "model.onnx", + } +] + + +class ColPali( + LateInteractionImageEmbeddingBase, OnnxTextModel[np.ndarray], OnnxImageModel[np.array] +): + DOCUMENT_MARKER_TOKEN_ID = 2 + + QUERY_PREFIX = "Query: " + BOS_TOKEN = "" + PAD_TOKEN = "" + QUERY_MARKER_TOKEN_ID = [2, 9413] + image_placeholder_size = (3, 448, 448) + EMPTY_TEXT_PLACEHOLDER = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) + EVEN_ATTENTION_MASK = np.array([1] * 1030) + + def _post_process_onnx_output( + self, + output: OnnxOutputContext, + ) -> Iterable[np.ndarray]: + return output.model_output.astype(np.float32) + + def _preprocess_image_input( + self, onnx_input: dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any + ) -> dict[str, np.ndarray]: + if is_doc: + onnx_input["input_ids"] = np.array( + [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array( + [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]] + ) + return onnx_input + else: + empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32) + onnx_input["pixel_values"] = np.array( + [empty_image_placeholder for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]]) + return onnx_input + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_colpali_models + + def _preprocess_queries(self, documents: list[str]): + texts_query: list[str] = [] + + for query in documents: + query = self.bos_token + self.query_prefix + query + self.pad_token * 10 + query += "\n" + + texts_query.append(query) + return texts_query + + def _preprocess_query_input( + self, inputs: list[Union[str]], **kwargs: Any + ) -> dict[str, np.ndarray]: + documents = self._preprocess_queries(inputs) + encoded = self.tokenize(documents, **kwargs) + input_ids = np.array([self.query_tokens + e.ids[2:] for e in encoded]) + + attention_mask = np.array([e.attention_mask for e in encoded]) + onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} + onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) + onnx_input["attention_mask"] = attention_mask + return onnx_input + + def _preprocess_images_input( + self, inputs: list[Union[ImageInput]], **kwargs: Any + ) -> dict[str, np.ndarray]: + with contextlib.ExitStack(): + image_files = [ + Image.open(image) if not isinstance(image, Image.Image) else image + for image in inputs + ] + encoded = self.processor(image_files) + onnx_input = self._build_onnx_input(encoded) + onnx_input = self._preprocess_image_input(onnx_input, **kwargs) + return onnx_input + + def embed( + self, + inputs: list[Union[str, Image]], + is_doc: bool = False, + **kwargs, + ) -> OnnxOutputContext: + if is_doc: + onnx_input = self._preprocess_query_input(inputs, **kwargs) + else: + onnx_input = self._preprocess_images_input(inputs, **kwargs) + + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) + return OnnxOutputContext( + model_output=model_output[0], + attention_mask=onnx_input.get("attention_mask", onnx_input["attention_mask"]), + input_ids=onnx_input.get("input_ids", onnx_input["input_ids"]), + ) + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + device_id: Optional[int] = None, + **kwargs, + ): + super().__init__(model_name, cache_dir, threads, **kwargs) + self.model_description = self._get_model_description(model_name) + self._model_dir = self.download_model( + self.model_description, self.cache_dir, local_files_only=self._local_files_only + ) + self.providers = providers + self.lazy_load = lazy_load + self.cuda = cuda + self.device_ids = device_ids + if device_id is not None: + self.device_id = device_id + elif self.device_ids is not None: + self.device_id = self.device_ids[0] + else: + self.device_id = None + self.load_onnx_model() + self.processor = load_preprocessor(model_dir=self._model_dir) + + # self.tokenizer.enable_truncation(max_length=10000) + + def load_onnx_model(self) -> None: + """ + Load the onnx model. + """ + self._load_onnx_model( + model_dir=self._model_dir, + model_file=self.model_description["model_file"], + threads=self.threads, + providers=self.providers, + cuda=self.cuda, + device_id=self.device_id, + ) + + +class ColPaliEmbeddingWorker(TextEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: + return ColPali( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/late_interaction/late_interaction_image_embedding.py b/fastembed/late_interaction/late_interaction_image_embedding.py new file mode 100644 index 00000000..fffddbd0 --- /dev/null +++ b/fastembed/late_interaction/late_interaction_image_embedding.py @@ -0,0 +1,113 @@ +from typing import Any, Iterable, Optional, Sequence, Type, Union + +import numpy as np + +from fastembed.common import OnnxProvider +from fastembed.late_interaction.colpali import ColPali +from fastembed.late_interaction.late_interaction_image_embedding_base import ( + LateInteractionImageEmbeddingBase, +) + + +class LateInteractionTextEmbedding(LateInteractionImageEmbeddingBase): + EMBEDDINGS_REGISTRY: list[Type[LateInteractionImageEmbeddingBase]] = [ColPali] + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """ + Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + + Example: + ``` + [ + { + "model": "colbert-ir/colbertv2.0", + "dim": 128, + "description": "Late interaction model", + "license": "mit", + "size_in_GB": 0.44, + "sources": { + "hf": "colbert-ir/colbertv2.0", + }, + "model_file": "model.onnx", + }, + ] + ``` + """ + result = [] + for embedding in cls.EMBEDDINGS_REGISTRY: + result.extend(embedding.list_supported_models()) + return result + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + **kwargs, + ): + super().__init__(model_name, cache_dir, threads, **kwargs) + for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: + supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() + if any(model_name.lower() == model["model"].lower() for model in supported_models): + self.model = EMBEDDING_MODEL_TYPE( + model_name, + cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + **kwargs, + ) + return + + raise ValueError( + f"Model {model_name} is not supported in LateInteractionTextEmbedding." + "Please check the supported models using `LateInteractionTextEmbedding.list_supported_models()`" + ) + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of documents into list of embeddings. + We use mean pooling with attention so that the model can handle variable-length inputs. + + Args: + documents: Iterator of documents or single document to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self.model.embed(documents, batch_size, parallel, **kwargs) + + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds queries + + Args: + query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. + + Returns: + Iterable[np.ndarray]: The embeddings. + """ + + # This is model-specific, so that different models can have specialized implementations + yield from self.model.query_embed(query, **kwargs) diff --git a/fastembed/late_interaction/late_interaction_image_embedding_base.py b/fastembed/late_interaction/late_interaction_image_embedding_base.py new file mode 100644 index 00000000..e4ae8996 --- /dev/null +++ b/fastembed/late_interaction/late_interaction_image_embedding_base.py @@ -0,0 +1,62 @@ +from typing import Iterable, Optional, Union + +import numpy as np + +from fastembed.common.model_management import ModelManagement +from fastembed.common.types import ImageInput + + +class LateInteractionImageEmbeddingBase(ModelManagement): + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs, + ): + self.model_name = model_name + self.cache_dir = cache_dir + self.threads = threads + self._local_files_only = kwargs.pop("local_files_only", False) + + def embed( + self, + images: Union[ImageInput, Iterable[ImageInput], str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + is_doc: bool = False, + **kwargs, + ) -> Iterable[np.ndarray]: + raise NotImplementedError() + + def image_embed(self, images: Iterable[ImageInput], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds a list of image passages into a list of embeddings. + + Args: + images (Iterable[str]): The list of images to embed. + **kwargs: Additional keyword argument to pass to the embed method. + + Yields: + Iterable[np.ndarray]: The embeddings. + """ + + # This is model-specific, so that different models can have specialized implementations + yield from self.embed(images, **kwargs) + + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds queries + + Args: + query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. + + Returns: + Iterable[np.ndarray]: The embeddings. + """ + + # This is model-specific, so that different models can have specialized implementations + if isinstance(query, str): + yield from self.embed([query], is_doc=True, **kwargs) + if isinstance(query, Iterable): + yield from self.embed(query, is_doc=True, **kwargs) diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 58c88411..dc7a719c 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -5,7 +5,7 @@ from fastembed.common import OnnxProvider from fastembed.late_interaction.colbert import Colbert from fastembed.late_interaction.jina_colbert import JinaColbert -from fastembed.late_interaction.late_interaction_embedding_base import ( +from fastembed.late_interaction.late_interaction_text_embedding_base import ( LateInteractionTextEmbeddingBase, ) From e43f680a24c0015f19098bbb240911f06e9c6120 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Wed, 27 Nov 2024 10:51:14 +0100 Subject: [PATCH 11/27] Refactoring to late interaction class --- ...edding_base.py => late_interaction_text_embedding_base.py} | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) rename fastembed/late_interaction/{late_interaction_embedding_base.py => late_interaction_text_embedding_base.py} (94%) diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_text_embedding_base.py similarity index 94% rename from fastembed/late_interaction/late_interaction_embedding_base.py rename to fastembed/late_interaction/late_interaction_text_embedding_base.py index 64fba498..2a587e01 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_text_embedding_base.py @@ -42,9 +42,7 @@ def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: # This is model-specific, so that different models can have specialized implementations yield from self.embed(texts, **kwargs) - def query_embed( - self, query: Union[str, Iterable[str]], **kwargs - ) -> Iterable[np.ndarray]: + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: """ Embeds queries From 423bb28fa5f7123fb481797a65b970c01b8a974a Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Wed, 27 Nov 2024 12:03:53 +0200 Subject: [PATCH 12/27] fix: Minor fix related to image.image --- fastembed/image/transform/functional.py | 2 +- fastembed/late_interaction/colpali.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index 70da2a22..380782f7 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -96,7 +96,7 @@ def normalize( def resize( - image: Image, + image: Image.Image, size: Union[int, tuple[int, int]], resample: Image.Resampling = Image.Resampling.BILINEAR, ) -> Image: diff --git a/fastembed/late_interaction/colpali.py b/fastembed/late_interaction/colpali.py index ddf4d700..b3f429d8 100644 --- a/fastembed/late_interaction/colpali.py +++ b/fastembed/late_interaction/colpali.py @@ -121,7 +121,7 @@ def _preprocess_images_input( def embed( self, - inputs: list[Union[str, Image]], + inputs: list[Union[str, Image.Image]], is_doc: bool = False, **kwargs, ) -> OnnxOutputContext: From dcae3ab3e89274c61061628dccea35f23bf1fefc Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Thu, 28 Nov 2024 11:33:36 +0100 Subject: [PATCH 13/27] Moved colpali to late_interaction --- fastembed/common/model_management.py | 2 + fastembed/late_interaction/colpali.py | 130 ++++++++++++------ pyproject.toml | 2 +- ... test_late_interaction_text_embeddings.py} | 0 4 files changed, 89 insertions(+), 45 deletions(-) rename tests/{test_late_interaction_embeddings.py => test_late_interaction_text_embeddings.py} (100%) diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 5ce95a49..d7423a99 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -119,6 +119,8 @@ def download_files_from_huggingface( if extra_patterns is not None: allow_patterns.extend(extra_patterns) + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + return snapshot_download( repo_id=hf_source_repo, allow_patterns=allow_patterns, diff --git a/fastembed/late_interaction/colpali.py b/fastembed/late_interaction/colpali.py index ddf4d700..9ef9c880 100644 --- a/fastembed/late_interaction/colpali.py +++ b/fastembed/late_interaction/colpali.py @@ -1,7 +1,7 @@ -from typing import Any, Iterable, Optional, Sequence, Union +from typing import Any, Iterable, Optional, Sequence, Union, List, Dict import numpy as np - +from sys import maxsize from fastembed.common import OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.image.onnx_image_model import OnnxImageModel @@ -18,7 +18,7 @@ supported_colpali_models = [ { "model": "akshayballal/colpali-v1.2-merged", - "dim": (16, 128), + "dim": 128, "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.", "license": "mit", "size_in_GB": 6.08, @@ -40,7 +40,6 @@ class ColPali( LateInteractionImageEmbeddingBase, OnnxTextModel[np.ndarray], OnnxImageModel[np.array] ): DOCUMENT_MARKER_TOKEN_ID = 2 - QUERY_PREFIX = "Query: " BOS_TOKEN = "" PAD_TOKEN = "" @@ -55,25 +54,6 @@ def _post_process_onnx_output( ) -> Iterable[np.ndarray]: return output.model_output.astype(np.float32) - def _preprocess_image_input( - self, onnx_input: dict[str, np.ndarray], is_doc: bool = True, **kwargs: Any - ) -> dict[str, np.ndarray]: - if is_doc: - onnx_input["input_ids"] = np.array( - [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]] - ) - onnx_input["attention_mask"] = np.array( - [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]] - ) - return onnx_input - else: - empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32) - onnx_input["pixel_values"] = np.array( - [empty_image_placeholder for _ in onnx_input["input_ids"]] - ) - onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]]) - return onnx_input - @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: """Lists the supported models. @@ -87,25 +67,12 @@ def _preprocess_queries(self, documents: list[str]): texts_query: list[str] = [] for query in documents: - query = self.bos_token + self.query_prefix + query + self.pad_token * 10 + query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 query += "\n" texts_query.append(query) return texts_query - def _preprocess_query_input( - self, inputs: list[Union[str]], **kwargs: Any - ) -> dict[str, np.ndarray]: - documents = self._preprocess_queries(inputs) - encoded = self.tokenize(documents, **kwargs) - input_ids = np.array([self.query_tokens + e.ids[2:] for e in encoded]) - - attention_mask = np.array([e.attention_mask for e in encoded]) - onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} - onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) - onnx_input["attention_mask"] = attention_mask - return onnx_input - def _preprocess_images_input( self, inputs: list[Union[ImageInput]], **kwargs: Any ) -> dict[str, np.ndarray]: @@ -121,22 +88,98 @@ def _preprocess_images_input( def embed( self, - inputs: list[Union[str, Image]], + inputs: Union[ImageInput, str], + batch_size: int = 16, + parallel: Optional[int] = None, is_doc: bool = False, **kwargs, ) -> OnnxOutputContext: if is_doc: - onnx_input = self._preprocess_query_input(inputs, **kwargs) + yield from self._embed_documents( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + documents=inputs, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + else: + # onnx_input = self._preprocess_images_input(inputs, **kwargs) + yield from self._embed_images( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + images=inputs, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + + def onnx_embed(self, inputs: Union[ImageInput, str], **kwargs) -> OnnxOutputContext: + if isinstance(inputs[0], str): + return self.onnx_embed_text(inputs, **kwargs) else: - onnx_input = self._preprocess_images_input(inputs, **kwargs) + return self.onnx_embed_image(inputs, **kwargs) + def onnx_embed_image(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: + with contextlib.ExitStack(): + image_files = [ + Image.open(image) if not isinstance(image, Image.Image) else image + for image in images + ] + encoded = self.processor(image_files) + onnx_input = self._build_onnx_input(encoded) + onnx_input = self._preprocess_onnx_image_input(onnx_input) + model_output = self.model.run(None, onnx_input) + embeddings = model_output[0].reshape(len(images), -1, self.model_description["dim"]) + return OnnxOutputContext(model_output=embeddings) + + def onnx_embed_text( + self, + documents: List[str], + **kwargs, + ) -> OnnxOutputContext: + documents = self._preprocess_queries(documents) + encoded = self.tokenize(documents, **kwargs) + input_ids = np.array([self.QUERY_MARKER_TOKEN_ID + e.ids[2:] for e in encoded]) + + attention_mask = np.array([e.attention_mask for e in encoded]) + onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} + onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs) + onnx_input["attention_mask"] = attention_mask model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) return OnnxOutputContext( model_output=model_output[0], - attention_mask=onnx_input.get("attention_mask", onnx_input["attention_mask"]), - input_ids=onnx_input.get("input_ids", onnx_input["input_ids"]), + attention_mask=onnx_input.get("attention_mask", attention_mask), + input_ids=onnx_input.get("input_ids", input_ids), ) + def _preprocess_onnx_image_input( + self, onnx_input: Dict[str, np.ndarray], **kwargs + ) -> Dict[str, np.ndarray]: + onnx_input["input_ids"] = np.array( + [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array( + [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]] + ) + return onnx_input + + def _preprocess_onnx_text_input( + self, onnx_input: Dict[str, np.ndarray], **kwargs + ) -> Dict[str, np.ndarray]: + empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32) + onnx_input["pixel_values"] = np.array( + [empty_image_placeholder for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]]) + return onnx_input + def __init__( self, model_name: str, @@ -166,8 +209,7 @@ def __init__( self.device_id = None self.load_onnx_model() self.processor = load_preprocessor(model_dir=self._model_dir) - - # self.tokenizer.enable_truncation(max_length=10000) + self.tokenizer.enable_truncation(max_length=maxsize) def load_onnx_model(self) -> None: """ diff --git a/pyproject.toml b/pyproject.toml index 39ecf6c6..64475d5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ onnxruntime = ">=1.17.0,<1.20.0" tqdm = "^4.66" requests = "^2.31" tokenizers = ">=0.15,<1.0" -huggingface-hub = ">=0.20,<1.0" +huggingface-hub = {version = ">=0.20,<1.0", extras = ["hf_transfer"]} loguru = "^0.7.2" numpy = [ { version = ">=1.21", python = "<3.12" }, diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_text_embeddings.py similarity index 100% rename from tests/test_late_interaction_embeddings.py rename to tests/test_late_interaction_text_embeddings.py From 62a065eb4ae228bed81b25e4164d194e092d43c3 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Thu, 28 Nov 2024 11:35:53 +0100 Subject: [PATCH 14/27] Removed colpali from text/image --- fastembed/image/colpali_model.py | 67 ----------------------- fastembed/image/image_embedding.py | 3 +- fastembed/text/colpali_model.py | 88 ------------------------------ 3 files changed, 1 insertion(+), 157 deletions(-) delete mode 100644 fastembed/image/colpali_model.py delete mode 100644 fastembed/text/colpali_model.py diff --git a/fastembed/image/colpali_model.py b/fastembed/image/colpali_model.py deleted file mode 100644 index fa0209a2..00000000 --- a/fastembed/image/colpali_model.py +++ /dev/null @@ -1,67 +0,0 @@ -import contextlib -from typing import Any, Dict, Iterable, List - -import numpy as np -from PIL import Image - -from fastembed.common import ImageInput -from fastembed.common.onnx_model import OnnxOutputContext -from fastembed.image.onnx_embedding import OnnxImageEmbedding - -supported_onnx_models = [ - { - "model": "akshayballal/colpali-v1.2-merged", - "dim": (1030, 128), - "description": "Image embeddings, Unimodal (image), Aligned to text latent space via PaliGemma-3B, 512 patches max, 2024.", - "license": "mit", - "size_in_GB": 6.08, - "sources": { - "hf": "akshayballal/colpali-v1.2-merged-onnx", - }, - "additional_files": ["model.onnx_data"], - "model_file": "model.onnx", - } -] - - -class ColpaliImageModel(OnnxImageEmbedding): - empty_text_placeholder = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) - even_attention_mask = np.array([1] * 1030) - - def _preprocess_onnx_input( - self, onnx_input: Dict[str, np.ndarray], **kwargs - ) -> Dict[str, np.ndarray]: - onnx_input["input_ids"] = np.array( - [self.empty_text_placeholder for _ in onnx_input["input_ids"]] - ) - onnx_input["attention_mask"] = np.array( - [self.even_attention_mask for _ in onnx_input["input_ids"]] - ) - return onnx_input - - @classmethod - def list_supported_models(cls) -> List[Dict[str, Any]]: - """ - Lists the supported models. - - Returns: - List[Dict[str, Any]]: A list of dictionaries containing the model information. - """ - return supported_onnx_models - - def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]: - return output.model_output.astype(np.float32) - - def onnx_embed(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: - with contextlib.ExitStack(): - image_files = [ - Image.open(image) if not isinstance(image, Image.Image) else image - for image in images - ] - encoded = self.processor(image_files) - onnx_input = self._build_onnx_input(encoded) - onnx_input = self._preprocess_onnx_input(onnx_input) - - model_output = self.model.run(None, onnx_input) - embeddings = model_output[0].reshape(len(images), *supported_onnx_models[0]["dim"]) - return OnnxOutputContext(model_output=embeddings) diff --git a/fastembed/image/image_embedding.py b/fastembed/image/image_embedding.py index f481f486..aa4c91b4 100644 --- a/fastembed/image/image_embedding.py +++ b/fastembed/image/image_embedding.py @@ -5,11 +5,10 @@ from fastembed.common import ImageInput, OnnxProvider from fastembed.image.image_embedding_base import ImageEmbeddingBase from fastembed.image.onnx_embedding import OnnxImageEmbedding -from fastembed.image.colpali_model import ColpaliImageModel class ImageEmbedding(ImageEmbeddingBase): - EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding, ColpaliImageModel] + EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding] @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: diff --git a/fastembed/text/colpali_model.py b/fastembed/text/colpali_model.py deleted file mode 100644 index 0320922c..00000000 --- a/fastembed/text/colpali_model.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Any, Dict, Iterable, List - -import numpy as np - -from fastembed.common.onnx_model import OnnxOutputContext -from fastembed.text.onnx_embedding import OnnxTextEmbedding - -supported_onnx_models = [ - { - "model": "akshayballal/colpali-v1.2-merged", - "dim": (16, 128), - "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.", - "license": "mit", - "size_in_GB": 6.08, - "sources": { - "hf": "akshayballal/colpali-v1.2-merged-onnx", - }, - "additional_files": [ - "model.onnx_data", - "tokenizer.json", - "tokenizer_config.json", - "config.json", - ], - "model_file": "model.onnx", - } -] - - -class ColpaliTextModel(OnnxTextEmbedding): - query_prefix = "Query: " - bos_token = "" - pad_token = "" - query_tokens = [2, 9413] - image_placeholder_size = (3, 448, 448) - - def _preprocess_onnx_input( - self, onnx_input: Dict[str, np.ndarray], **kwargs - ) -> Dict[str, np.ndarray]: - empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32) - onnx_input["pixel_values"] = np.array( - [empty_image_placeholder for _ in onnx_input["input_ids"]] - ) - onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]]) - return onnx_input - - @classmethod - def list_supported_models(cls) -> List[Dict[str, Any]]: - """ - Lists the supported models. - - Returns: - List[Dict[str, Any]]: A list of dictionaries containing the model information. - """ - return supported_onnx_models - - def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]: - return output.model_output.astype(np.float32) - - def _preprocess_queries(self, documents: List[str]): - texts_query: List[str] = [] - - for query in documents: - query = self.bos_token + self.query_prefix + query + self.pad_token * 10 - query += "\n" - - texts_query.append(query) - return texts_query - - def onnx_embed( - self, - documents: List[str], - **kwargs, - ) -> OnnxOutputContext: - documents = self._preprocess_queries(documents) - self.tokenizer.enable_truncation(max_length=10000) - encoded = self.tokenize(documents, **kwargs) - input_ids = np.array([self.query_tokens + e.ids[2:] for e in encoded]) - - attention_mask = np.array([e.attention_mask for e in encoded]) - onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} - onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) - onnx_input["attention_mask"] = attention_mask - model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) - return OnnxOutputContext( - model_output=model_output[0], - attention_mask=onnx_input.get("attention_mask", attention_mask), - input_ids=onnx_input.get("input_ids", input_ids), - ) From c040120ca855c2c5fc64b3eabe2a00073dfc2af8 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Thu, 28 Nov 2024 12:17:26 +0100 Subject: [PATCH 15/27] Tests draft --- .../late_interaction_image_embedding.py | 2 +- .../test_late_interaction_image_embeddings.py | 160 ++++++++++++++++++ 2 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 tests/test_late_interaction_image_embeddings.py diff --git a/fastembed/late_interaction/late_interaction_image_embedding.py b/fastembed/late_interaction/late_interaction_image_embedding.py index fffddbd0..748d861e 100644 --- a/fastembed/late_interaction/late_interaction_image_embedding.py +++ b/fastembed/late_interaction/late_interaction_image_embedding.py @@ -9,7 +9,7 @@ ) -class LateInteractionTextEmbedding(LateInteractionImageEmbeddingBase): +class LateInteractionImageEmbedding(LateInteractionImageEmbeddingBase): EMBEDDINGS_REGISTRY: list[Type[LateInteractionImageEmbeddingBase]] = [ColPali] @classmethod diff --git a/tests/test_late_interaction_image_embeddings.py b/tests/test_late_interaction_image_embeddings.py new file mode 100644 index 00000000..1da8f6c6 --- /dev/null +++ b/tests/test_late_interaction_image_embeddings.py @@ -0,0 +1,160 @@ +import os + +import numpy as np +import pytest + +from fastembed.late_interaction.late_interaction_image_embedding import ( + LateInteractionImageEmbedding, +) +from tests.utils import delete_model_cache +from tests.config import TEST_MISC_DIR +from PIL import Image + +# vectors are abridged and rounded for brevity +CANONICAL_COLUMN_VALUES = { + "akshayballal/colpali-v1.2-merged": np.array( + [ + [ + [0.015, 0.051, 0.059, 0.026, -0.061, -0.027, -0.014], + [-0.22, -0.111, 0.046, 0.081, -0.048, -0.052, -0.086], + [-0.184, -0.131, 0.004, 0.062, -0.038, -0.059, -0.127], + [-0.209, -0.113, 0.015, 0.059, -0.035, -0.035, -0.072], + [-0.031, -0.044, 0.092, -0.005, 0.006, -0.057, -0.061], + [-0.18, -0.039, 0.031, 0.003, 0.083, -0.041, 0.088], + [-0.091, 0.023, 0.116, -0.02, 0.039, -0.064, -0.026], + ], + [ + [-0.25, -0.112, -0.065, -0.014, 0.005, -0.092, 0.024], + [-0.22, -0.096, -0.014, 0.039, -0.02, -0.12, -0.004], + [-0.228, -0.114, 0.031, 0.019, 0.034, -0.052, -0.031], + [-0.274, -0.186, 0.095, -0.019, 0.017, 0.021, -0.016], + [-0.186, -0.061, -0.01, 0.065, -0.058, -0.05, 0.019], + [-0.183, -0.11, -0.034, -0.042, 0.026, -0.071, 0.02], + [-0.153, -0.072, -0.015, 0.088, -0.081, -0.043, 0.04], + ], + ] + ), +} + +CANONICAL_QUERY_VALUES = { + "akshayballal/colpali-v1.2-merged": np.array( + [ + [0.158, -0.02, 0.1, -0.023, 0.045, 0.031, 0.071], + [-0.074, -0.111, 0.065, -0.0, -0.089, -0.003, -0.099], + [-0.034, -0.014, 0.174, -0.063, -0.09, -0.036, 0.064], + [-0.07, -0.014, 0.186, -0.013, -0.021, -0.062, 0.107], + [-0.085, 0.025, 0.179, -0.101, 0.036, -0.089, 0.098], + [-0.058, 0.031, 0.18, -0.078, 0.023, -0.119, 0.131], + [-0.067, 0.038, 0.188, -0.079, -0.001, -0.123, 0.127], + [-0.063, 0.037, 0.204, -0.069, 0.003, -0.118, 0.134], + [-0.054, 0.036, 0.212, -0.072, -0.001, -0.117, 0.133], + [-0.044, 0.03, 0.218, -0.077, -0.003, -0.107, 0.139], + [-0.037, 0.033, 0.22, -0.088, 0.0, -0.095, 0.146], + [-0.031, 0.041, 0.213, -0.092, 0.001, -0.088, 0.147], + [-0.026, 0.047, 0.204, -0.089, -0.002, -0.084, 0.144], + [-0.027, 0.051, 0.199, -0.084, -0.007, -0.083, 0.14], + [-0.031, 0.056, 0.19, -0.082, -0.011, -0.086, 0.135], + [-0.008, 0.108, 0.144, -0.095, -0.018, -0.086, 0.085], + ] + ), +} + +queries = ["hello world", "flag embedding"] +images = [ + TEST_MISC_DIR / "image.jpeg", + str(TEST_MISC_DIR / "small_image.jpeg"), + Image.open((TEST_MISC_DIR / "small_image.jpeg")), +] + + +def test_batch_embedding(): + is_ci = os.getenv("CI") + docs_to_embed = images * 10 + + for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionImageEmbedding(model_name=model_name) + result = list(model.embed(docs_to_embed, batch_size=6)) + + for value in result: + token_num, abridged_dim = expected_result.shape + assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding(): + is_ci = os.getenv("CI") + docs_to_embed = images + + for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionImageEmbedding(model_name=model_name) + result = next(iter(model.embed(docs_to_embed, batch_size=6))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding_query(): + is_ci = os.getenv("CI") + queries_to_embed = queries + + for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionImageEmbedding(model_name=model_name) + result = next(iter(model.query_embed(queries_to_embed))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_parallel_processing(): + is_ci = os.getenv("CI") + model = LateInteractionImageEmbedding(model_name="colbert-ir/colbertv2.0") + token_dim = 128 + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) + + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) + + embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) + + assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == token_dim + assert np.allclose(embeddings, embeddings_2, atol=1e-3) + assert np.allclose(embeddings, embeddings_3, atol=1e-3) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize( + "model_name", + ["colbert-ir/colbertv2.0"], +) +def test_lazy_load(model_name): + is_ci = os.getenv("CI") + + model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) + assert not hasattr(model.model, "model") + + docs = ["hello world", "flag embedding"] + list(model.embed(docs)) + assert hasattr(model.model, "model") + + model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) + list(model.query_embed(docs)) + + model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) + list(model.passage_embed(docs)) + + if is_ci: + delete_model_cache(model.model._model_dir) From 367178b904092707f562a823381a82f3c22a7ad5 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 11:44:34 +0100 Subject: [PATCH 16/27] Tests draft --- experiments/colpali_convert_lang_model.py | 31 ++ experiments/colpali_image_test.ipynb | 381 ++++++++++++++++++ experiments/colpali_text_test.ipynb | 444 +++++++++++++++++++++ experiments/late_interaction_colpali.ipynb | 202 ++++++++++ fastembed/common/model_management.py | 2 - fastembed/late_interaction/colpali.py | 6 +- 6 files changed, 1062 insertions(+), 4 deletions(-) create mode 100644 experiments/colpali_convert_lang_model.py create mode 100644 experiments/colpali_image_test.ipynb create mode 100644 experiments/colpali_text_test.ipynb create mode 100644 experiments/late_interaction_colpali.ipynb diff --git a/experiments/colpali_convert_lang_model.py b/experiments/colpali_convert_lang_model.py new file mode 100644 index 00000000..691f0cba --- /dev/null +++ b/experiments/colpali_convert_lang_model.py @@ -0,0 +1,31 @@ +import torch +from colpali_engine.models import ColPali, ColPaliProcessor +import onnxruntime as ort + +model_name = "vidore/colpali-v1.2" +original_model = ColPali.from_pretrained(model_name).eval() +processor = ColPaliProcessor.from_pretrained(model_name) + +dummy_query = ["Is attention really all you need?"] + +# Process the input query +processed_query = processor.process_queries(dummy_query).to(original_model.device) + +# Prepare input tensors +input_query_tensor = processed_query["input_ids"].type(torch.long) +attention_mask_tensor = processed_query["attention_mask"].type(torch.long) + +# Export the model to ONNX with the required inputs and dynamic shapes +torch.onnx.export( + original_model.model.language_model, + (input_query_tensor, attention_mask_tensor), + "experiments/colpali_text_encoder_dir/model.onnx", + input_names=["input_ids", "attention_mask"], + output_names=["logits"], + dynamo=True, + opset_version=14, +) + + +image_session = ort.InferenceSession("experiments/colpali_text_encoder_dir/model.onnx") +print("Session output", image_session((input_query_tensor, attention_mask_tensor))) diff --git a/experiments/colpali_image_test.ipynb b/experiments/colpali_image_test.ipynb new file mode 100644 index 00000000..354bc596 --- /dev/null +++ b/experiments/colpali_image_test.ipynb @@ -0,0 +1,381 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T10:02:39.315496Z", + "start_time": "2024-11-28T10:02:39.290846Z" + }, + "collapsed": true + }, + "outputs": [], + "source": [ + "from PIL import Image\n", + "\n", + "images = [\n", + " Image.open(\"/Users/d.rudenko/PycharmProjects/opensource/fastembed/tests/misc/image.jpeg\"),\n", + " Image.open(\n", + " \"/Users/d.rudenko/PycharmProjects/opensource/fastembed/tests/misc/small_image.jpeg\"\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e46189ce4b8b0677", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T09:58:37.254586Z", + "start_time": "2024-11-28T09:58:22.754066Z" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ba9856c5109643049718592a236b2206", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 7 files: 0%| | 0/7 [00:00` tokens in the very beginning of your text and `` token after that. For this call, we will infer how many images each text has and add special tokens.\n" + ] + } + ], + "source": [ + "from colpali_engine.models import ColPaliProcessor\n", + "\n", + "model_name = \"vidore/colpali-v1.2-merged\"\n", + "\n", + "processor = ColPaliProcessor.from_pretrained(model_name)\n", + "# Process the inputs\n", + "batch_images_onnx = processor.process_images(images)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "89c2fbe3d64964fc", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T10:03:56.766986Z", + "start_time": "2024-11-28T10:02:43.893495Z" + } + }, + "outputs": [], + "source": [ + "import onnxruntime as ort\n", + "\n", + "sess = ort.InferenceSession(\"/Users/d.rudenko/dev/qdrant/colpali-v1.2-merged-onnx/model.onnx\")\n", + "image_embeddings_onnx = sess.run(\n", + " [sess.get_outputs()[0].name],\n", + " {\n", + " \"input_ids\": batch_images_onnx[\"input_ids\"].numpy(),\n", + " \"pixel_values\": batch_images_onnx[\"pixel_values\"].numpy(),\n", + " \"attention_mask\": batch_images_onnx[\"attention_mask\"].numpy(),\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "61b43dd6caaa0909", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T10:06:23.238770Z", + "start_time": "2024-11-28T10:06:23.235457Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 2, 1030, 128)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "\n", + "np.array(image_embeddings_onnx).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5be8ebb15c6dfaa6", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T10:59:48.765049Z", + "start_time": "2024-11-28T10:59:48.761122Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[[ 0.015 0.051 0.059 0.026 -0.061 -0.027 -0.014]\n", + " [-0.22 -0.111 0.046 0.081 -0.048 -0.052 -0.086]\n", + " [-0.184 -0.131 0.004 0.062 -0.038 -0.059 -0.127]\n", + " [-0.209 -0.113 0.015 0.059 -0.035 -0.035 -0.072]\n", + " [-0.031 -0.044 0.092 -0.005 0.006 -0.057 -0.061]\n", + " [-0.18 -0.039 0.031 0.003 0.083 -0.041 0.088]\n", + " [-0.091 0.023 0.116 -0.02 0.039 -0.064 -0.026]]\n", + "\n", + " [[-0.25 -0.112 -0.065 -0.014 0.005 -0.092 0.024]\n", + " [-0.22 -0.096 -0.014 0.039 -0.02 -0.12 -0.004]\n", + " [-0.228 -0.114 0.031 0.019 0.034 -0.052 -0.031]\n", + " [-0.274 -0.186 0.095 -0.019 0.017 0.021 -0.016]\n", + " [-0.186 -0.061 -0.01 0.065 -0.058 -0.05 0.019]\n", + " [-0.183 -0.11 -0.034 -0.042 0.026 -0.071 0.02 ]\n", + " [-0.153 -0.072 -0.015 0.088 -0.081 -0.043 0.04 ]]]\n" + ] + } + ], + "source": [ + "print(np.round(image_embeddings_onnx[0][:, :7, :7], decimals=3))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "bc9f7ffda971d3ba", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T10:59:02.286294Z", + "start_time": "2024-11-28T10:59:02.264997Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0.01533 , 0.05118 , 0.05948 , 0.02583 , -0.06128 , -0.02682 ,\n", + " -0.013565, 0.10254 , -0.0983 , 0.1109 , -0.00342 , -0.0344 ,\n", + " -0.00887 , -0.1616 , 0.09814 , 0.2257 , 0.03976 , 0.03687 ,\n", + " 0.1648 , 0.06866 , 0.0396 , 0.1672 , 0.1455 , -0.1387 ,\n", + " 0.1203 , 0.04907 , -0.07965 , -0.0885 , 0.01982 , 0.0404 ,\n", + " -0.07513 , -0.02844 , 0.04337 , 0.03857 , -0.1065 , 0.0288 ,\n", + " -0.1279 , -0.1126 , 0.03363 , -0.0507 , 0.11584 , 0.0483 ,\n", + " 0.035 , -0.08417 , -0.0907 , 0.0279 , 0.1394 , -0.10364 ,\n", + " -0.1471 , -0.07135 , -0.136 , 0.1289 , 0.082 , 0.02232 ,\n", + " -0.00571 , -0.02547 , 0.1053 , 0.0377 , 0.0148 , 0.02795 ,\n", + " -0.01859 , -0.11066 , -0.12195 , 0.0583 , 0.0995 , 0.01086 ,\n", + " 0.0859 , 0.1302 , -0.10126 , 0.005417, 0.05423 , -0.1808 ,\n", + " 0.1444 , 0.1885 , 0.09247 , -0.04718 , 0.1018 , -0.02997 ,\n", + " -0.0598 , -0.011284, 0.1203 , -0.1313 , -0.04584 , -0.02725 ,\n", + " -0.1277 , -0.04236 , -0.08466 , -0.0861 , 0.1131 , 0.02806 ,\n", + " -0.0947 , 0.04388 , 0.04263 , 0.03598 , -0.06866 , -0.06018 ,\n", + " -0.02763 , -0.0972 , 0.11505 , -0.1097 , -0.04166 , 0.0742 ,\n", + " -0.06683 , -0.02188 , -0.1663 , -0.0902 , 0.02594 , -0.03802 ,\n", + " -0.034 , -0.04828 , -0.05765 , 0.0633 , -0.02515 , -0.08826 ,\n", + " -0.09753 , -0.10974 , -0.074 , -0.02083 , -0.1301 , 0.1383 ,\n", + " 0.1428 , 0.0935 , 0.0949 , 0.03876 , 0.08514 , -0.12256 ,\n", + " -0.0451 , -0.002306], dtype=float16)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.array(image_embeddings_onnx)[0][0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "34a238b20e5fcab2", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-20T15:39:41.314176Z", + "start_time": "2024-11-20T15:39:41.308579Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "\n", + "np.allclose(image_embeddings_onnx[0][0], fastembed_i_embeddings[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5ca3b11eb3813a87", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-20T15:39:42.081408Z", + "start_time": "2024-11-20T15:39:42.078582Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0.01533 , 0.05118 , 0.05948 , 0.02583 , -0.06128 , -0.02682 ,\n", + " -0.013565, 0.10254 , -0.0983 , 0.1109 , -0.00342 , -0.0344 ,\n", + " -0.00887 , -0.1616 , 0.09814 , 0.2257 , 0.03976 , 0.03687 ,\n", + " 0.1648 , 0.06866 , 0.0396 , 0.1672 , 0.1455 , -0.1387 ,\n", + " 0.1203 , 0.04907 , -0.07965 , -0.0885 , 0.01982 , 0.0404 ,\n", + " -0.07513 , -0.02844 , 0.04337 , 0.03857 , -0.1065 , 0.0288 ,\n", + " -0.1279 , -0.1126 , 0.03363 , -0.0507 , 0.11584 , 0.0483 ,\n", + " 0.035 , -0.08417 , -0.0907 , 0.0279 , 0.1394 , -0.10364 ,\n", + " -0.1471 , -0.07135 , -0.136 , 0.1289 , 0.082 , 0.02232 ,\n", + " -0.00571 , -0.02547 , 0.1053 , 0.0377 , 0.0148 , 0.02795 ,\n", + " -0.01859 , -0.11066 , -0.12195 , 0.0583 , 0.0995 , 0.01086 ,\n", + " 0.0859 , 0.1302 , -0.10126 , 0.005417, 0.05423 , -0.1808 ,\n", + " 0.1444 , 0.1885 , 0.09247 , -0.04718 , 0.1018 , -0.02997 ,\n", + " -0.0598 , -0.011284, 0.1203 , -0.1313 , -0.04584 , -0.02725 ,\n", + " -0.1277 , -0.04236 , -0.08466 , -0.0861 , 0.1131 , 0.02806 ,\n", + " -0.0947 , 0.04388 , 0.04263 , 0.03598 , -0.06866 , -0.06018 ,\n", + " -0.02763 , -0.0972 , 0.11505 , -0.1097 , -0.04166 , 0.0742 ,\n", + " -0.06683 , -0.02188 , -0.1663 , -0.0902 , 0.02594 , -0.03802 ,\n", + " -0.034 , -0.04828 , -0.05765 , 0.0633 , -0.02515 , -0.08826 ,\n", + " -0.09753 , -0.10974 , -0.074 , -0.02083 , -0.1301 , 0.1383 ,\n", + " 0.1428 , 0.0935 , 0.0949 , 0.03876 , 0.08514 , -0.12256 ,\n", + " -0.0451 , -0.002306], dtype=float16)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "image_embeddings_onnx[0][0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2c52a4d7d83aeda7", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-20T16:05:32.115768Z", + "start_time": "2024-11-20T16:05:32.090218Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[ 0.01532745, 0.05117798, 0.05947876, ..., -0.12255859,\n", + " -0.04510498, -0.00230598],\n", + " [-0.22009277, -0.11071777, 0.04562378, ..., 0.00257111,\n", + " -0.06988525, 0.12384033],\n", + " [-0.18371582, -0.13085938, 0.00393677, ..., -0.02949524,\n", + " -0.05444336, 0.1295166 ],\n", + " ...,\n", + " [-0.1418457 , 0.01023102, 0.1239624 , ..., -0.00460434,\n", + " 0.17321777, 0.09454346],\n", + " [-0.24572754, -0.06878662, 0.11834717, ..., -0.02763367,\n", + " -0.03022766, 0.08917236],\n", + " [-0.2211914 , -0.04171753, 0.19519043, ..., -0.01535797,\n", + " -0.02432251, -0.03561401]], dtype=float32),\n", + " array([[-2.49877930e-01, -1.11511230e-01, -6.51855469e-02, ...,\n", + " 3.19519043e-02, 3.44543457e-02, -1.33666992e-02],\n", + " [-2.20336914e-01, -9.56420898e-02, -1.39694214e-02, ...,\n", + " -8.88705254e-05, -1.57318115e-02, -1.00555420e-02],\n", + " [-2.28271484e-01, -1.14501953e-01, 3.10058594e-02, ...,\n", + " 7.59277344e-02, -4.28466797e-02, 1.19262695e-01],\n", + " ...,\n", + " [-2.04589844e-01, -4.86755371e-02, 8.46557617e-02, ...,\n", + " -3.98254395e-02, 1.66625977e-01, 9.71679688e-02],\n", + " [-2.88085938e-01, -4.50439453e-02, 7.69653320e-02, ...,\n", + " -4.36096191e-02, -1.28784180e-02, 6.26220703e-02],\n", + " [-2.67578125e-01, -3.25317383e-02, 1.66625977e-01, ...,\n", + " -2.90679932e-03, -1.52282715e-02, -3.62243652e-02]], dtype=float32)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fastembed_i_embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "786bfac25eb7704a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/experiments/colpali_text_test.ipynb b/experiments/colpali_text_test.ipynb new file mode 100644 index 00000000..b3392089 --- /dev/null +++ b/experiments/colpali_text_test.ipynb @@ -0,0 +1,444 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "54b3bfd4ad5b9ee6", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T10:44:32.841758Z", + "start_time": "2024-11-28T10:44:32.830025Z" + } + }, + "outputs": [], + "source": [ + "# Your inputs\n", + "queries = [\n", + " # \"Is attention really all you need?\",\n", + " # \"Are Benjamin, Antoine, Merve, and Jo best friends?\",\n", + " # \"Long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long\"\n", + " \"hello world\",\n", + " \"flag embedding\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "86ee1b68fb88b11d", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-27T22:54:23.016952Z", + "start_time": "2024-11-27T22:33:14.872976Z" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f1b463d5ae47404f951fecc6629e8008", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 7 files: 0%| | 0/7 [00:00 None: """ @@ -223,6 +224,7 @@ def load_onnx_model(self) -> None: cuda=self.cuda, device_id=self.device_id, ) + self.tokenizer.enable_truncation(max_length=maxsize) class ColPaliEmbeddingWorker(TextEmbeddingWorker): From 1fff39b0be733f2783dab3212fa2fa9334f88ec9 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 11:45:18 +0100 Subject: [PATCH 17/27] Tests draft --- experiments/colpali_convert_lang_model.py | 31 ----------------------- 1 file changed, 31 deletions(-) delete mode 100644 experiments/colpali_convert_lang_model.py diff --git a/experiments/colpali_convert_lang_model.py b/experiments/colpali_convert_lang_model.py deleted file mode 100644 index 691f0cba..00000000 --- a/experiments/colpali_convert_lang_model.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from colpali_engine.models import ColPali, ColPaliProcessor -import onnxruntime as ort - -model_name = "vidore/colpali-v1.2" -original_model = ColPali.from_pretrained(model_name).eval() -processor = ColPaliProcessor.from_pretrained(model_name) - -dummy_query = ["Is attention really all you need?"] - -# Process the input query -processed_query = processor.process_queries(dummy_query).to(original_model.device) - -# Prepare input tensors -input_query_tensor = processed_query["input_ids"].type(torch.long) -attention_mask_tensor = processed_query["attention_mask"].type(torch.long) - -# Export the model to ONNX with the required inputs and dynamic shapes -torch.onnx.export( - original_model.model.language_model, - (input_query_tensor, attention_mask_tensor), - "experiments/colpali_text_encoder_dir/model.onnx", - input_names=["input_ids", "attention_mask"], - output_names=["logits"], - dynamo=True, - opset_version=14, -) - - -image_session = ort.InferenceSession("experiments/colpali_text_encoder_dir/model.onnx") -print("Session output", image_session((input_query_tensor, attention_mask_tensor))) From 667eee148d743d70a29ee04660e40f47a4554795 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 16:02:33 +0100 Subject: [PATCH 18/27] Tests another draft --- fastembed/late_interaction/colpali.py | 6 +++- .../test_late_interaction_image_embeddings.py | 30 +++++++------------ 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/fastembed/late_interaction/colpali.py b/fastembed/late_interaction/colpali.py index eee1c3a7..3f557ac8 100644 --- a/fastembed/late_interaction/colpali.py +++ b/fastembed/late_interaction/colpali.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Optional, Sequence, Union, List, Dict +from typing import Any, Iterable, Optional, Sequence, Union, List, Dict, Type import numpy as np from sys import maxsize @@ -226,6 +226,10 @@ def load_onnx_model(self) -> None: ) self.tokenizer.enable_truncation(max_length=maxsize) + @classmethod + def _get_worker_class(cls) -> Type[TextEmbeddingWorker]: + return ColPaliEmbeddingWorker + class ColPaliEmbeddingWorker(TextEmbeddingWorker): def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: diff --git a/tests/test_late_interaction_image_embeddings.py b/tests/test_late_interaction_image_embeddings.py index 1da8f6c6..be9c56ca 100644 --- a/tests/test_late_interaction_image_embeddings.py +++ b/tests/test_late_interaction_image_embeddings.py @@ -22,16 +22,7 @@ [-0.031, -0.044, 0.092, -0.005, 0.006, -0.057, -0.061], [-0.18, -0.039, 0.031, 0.003, 0.083, -0.041, 0.088], [-0.091, 0.023, 0.116, -0.02, 0.039, -0.064, -0.026], - ], - [ - [-0.25, -0.112, -0.065, -0.014, 0.005, -0.092, 0.024], - [-0.22, -0.096, -0.014, 0.039, -0.02, -0.12, -0.004], - [-0.228, -0.114, 0.031, 0.019, 0.034, -0.052, -0.031], - [-0.274, -0.186, 0.095, -0.019, 0.017, 0.021, -0.016], - [-0.186, -0.061, -0.01, 0.065, -0.058, -0.05, 0.019], - [-0.183, -0.11, -0.034, -0.042, 0.026, -0.071, 0.02], - [-0.153, -0.072, -0.015, 0.088, -0.081, -0.043, 0.04], - ], + ] ] ), } @@ -69,7 +60,7 @@ def test_batch_embedding(): is_ci = os.getenv("CI") - docs_to_embed = images * 10 + docs_to_embed = images for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): print("evaluating", model_name) @@ -77,8 +68,9 @@ def test_batch_embedding(): result = list(model.embed(docs_to_embed, batch_size=6)) for value in result: - token_num, abridged_dim = expected_result.shape - assert np.allclose(value[:, :abridged_dim], expected_result, atol=2e-3) + batch_size, token_num, abridged_dim = expected_result.shape + assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=1e-3) + break if is_ci: delete_model_cache(model.model._model_dir) @@ -92,8 +84,8 @@ def test_single_embedding(): print("evaluating", model_name) model = LateInteractionImageEmbedding(model_name=model_name) result = next(iter(model.embed(docs_to_embed, batch_size=6))) - token_num, abridged_dim = expected_result.shape - assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3) + batch_size, token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) if is_ci: delete_model_cache(model.model._model_dir) @@ -108,7 +100,7 @@ def test_single_embedding_query(): model = LateInteractionImageEmbedding(model_name=model_name) result = next(iter(model.query_embed(queries_to_embed))) token_num, abridged_dim = expected_result.shape - assert np.allclose(result[:, :abridged_dim], expected_result, atol=2e-3) + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) if is_ci: delete_model_cache(model.model._model_dir) @@ -116,7 +108,7 @@ def test_single_embedding_query(): def test_parallel_processing(): is_ci = os.getenv("CI") - model = LateInteractionImageEmbedding(model_name="colbert-ir/colbertv2.0") + model = LateInteractionImageEmbedding(model_name="akshayballal/colpali-v1.2-merged") token_dim = 128 docs = ["hello world", "flag embedding"] * 100 embeddings = list(model.embed(docs, batch_size=10, parallel=2)) @@ -138,7 +130,7 @@ def test_parallel_processing(): @pytest.mark.parametrize( "model_name", - ["colbert-ir/colbertv2.0"], + ["akshayballal/colpali-v1.2-merged"], ) def test_lazy_load(model_name): is_ci = os.getenv("CI") @@ -154,7 +146,7 @@ def test_lazy_load(model_name): list(model.query_embed(docs)) model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) - list(model.passage_embed(docs)) + list(model.embed(docs)) if is_ci: delete_model_cache(model.model._model_dir) From a9bddf38b0a3cb7d47e9b06a9e351bcaa39745a3 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 17:42:06 +0100 Subject: [PATCH 19/27] Reduce batch size for test --- tests/test_late_interaction_image_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_late_interaction_image_embeddings.py b/tests/test_late_interaction_image_embeddings.py index be9c56ca..d29294d2 100644 --- a/tests/test_late_interaction_image_embeddings.py +++ b/tests/test_late_interaction_image_embeddings.py @@ -65,7 +65,7 @@ def test_batch_embedding(): for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): print("evaluating", model_name) model = LateInteractionImageEmbedding(model_name=model_name) - result = list(model.embed(docs_to_embed, batch_size=6)) + result = list(model.embed(docs_to_embed, batch_size=2)) for value in result: batch_size, token_num, abridged_dim = expected_result.shape From e747c34eccb4265b4a8035b891905118469bd1b2 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 18:37:43 +0100 Subject: [PATCH 20/27] CI pre-clean up --- tests/test_late_interaction_image_embeddings.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_late_interaction_image_embeddings.py b/tests/test_late_interaction_image_embeddings.py index d29294d2..a467d7eb 100644 --- a/tests/test_late_interaction_image_embeddings.py +++ b/tests/test_late_interaction_image_embeddings.py @@ -60,6 +60,9 @@ def test_batch_embedding(): is_ci = os.getenv("CI") + if is_ci: + delete_model_cache(model.model._model_dir) + docs_to_embed = images for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): @@ -78,6 +81,9 @@ def test_batch_embedding(): def test_single_embedding(): is_ci = os.getenv("CI") + if is_ci: + delete_model_cache(model.model._model_dir) + docs_to_embed = images for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): @@ -93,6 +99,9 @@ def test_single_embedding(): def test_single_embedding_query(): is_ci = os.getenv("CI") + if is_ci: + delete_model_cache(model.model._model_dir) + queries_to_embed = queries for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): @@ -108,6 +117,9 @@ def test_single_embedding_query(): def test_parallel_processing(): is_ci = os.getenv("CI") + if is_ci: + delete_model_cache(model.model._model_dir) + model = LateInteractionImageEmbedding(model_name="akshayballal/colpali-v1.2-merged") token_dim = 128 docs = ["hello world", "flag embedding"] * 100 @@ -134,6 +146,8 @@ def test_parallel_processing(): ) def test_lazy_load(model_name): is_ci = os.getenv("CI") + if is_ci: + delete_model_cache(model.model._model_dir) model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) assert not hasattr(model.model, "model") From 0ff8f49fe3c5bfd1511c64f54950f9ceb83e7086 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 18:40:10 +0100 Subject: [PATCH 21/27] Clean_up was a lie --- tests/test_late_interaction_image_embeddings.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/test_late_interaction_image_embeddings.py b/tests/test_late_interaction_image_embeddings.py index a467d7eb..7d048e36 100644 --- a/tests/test_late_interaction_image_embeddings.py +++ b/tests/test_late_interaction_image_embeddings.py @@ -60,9 +60,6 @@ def test_batch_embedding(): is_ci = os.getenv("CI") - if is_ci: - delete_model_cache(model.model._model_dir) - docs_to_embed = images for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): @@ -81,8 +78,6 @@ def test_batch_embedding(): def test_single_embedding(): is_ci = os.getenv("CI") - if is_ci: - delete_model_cache(model.model._model_dir) docs_to_embed = images @@ -99,8 +94,6 @@ def test_single_embedding(): def test_single_embedding_query(): is_ci = os.getenv("CI") - if is_ci: - delete_model_cache(model.model._model_dir) queries_to_embed = queries @@ -117,10 +110,9 @@ def test_single_embedding_query(): def test_parallel_processing(): is_ci = os.getenv("CI") - if is_ci: - delete_model_cache(model.model._model_dir) model = LateInteractionImageEmbedding(model_name="akshayballal/colpali-v1.2-merged") + token_dim = 128 docs = ["hello world", "flag embedding"] * 100 embeddings = list(model.embed(docs, batch_size=10, parallel=2)) @@ -146,8 +138,6 @@ def test_parallel_processing(): ) def test_lazy_load(model_name): is_ci = os.getenv("CI") - if is_ci: - delete_model_cache(model.model._model_dir) model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) assert not hasattr(model.model, "model") From 274e0d799185d3aed1e24a2079bad5636d403f3d Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 19:33:50 +0100 Subject: [PATCH 22/27] Clean up non-needed notebooks --- experiments/colpali_image_test.ipynb | 381 ------------------ experiments/colpali_text_test.ipynb | 444 --------------------- experiments/late_interaction_colpali.ipynb | 202 ---------- 3 files changed, 1027 deletions(-) delete mode 100644 experiments/colpali_image_test.ipynb delete mode 100644 experiments/colpali_text_test.ipynb delete mode 100644 experiments/late_interaction_colpali.ipynb diff --git a/experiments/colpali_image_test.ipynb b/experiments/colpali_image_test.ipynb deleted file mode 100644 index 354bc596..00000000 --- a/experiments/colpali_image_test.ipynb +++ /dev/null @@ -1,381 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "id": "initial_id", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-28T10:02:39.315496Z", - "start_time": "2024-11-28T10:02:39.290846Z" - }, - "collapsed": true - }, - "outputs": [], - "source": [ - "from PIL import Image\n", - "\n", - "images = [\n", - " Image.open(\"/Users/d.rudenko/PycharmProjects/opensource/fastembed/tests/misc/image.jpeg\"),\n", - " Image.open(\n", - " \"/Users/d.rudenko/PycharmProjects/opensource/fastembed/tests/misc/small_image.jpeg\"\n", - " ),\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e46189ce4b8b0677", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-28T09:58:37.254586Z", - "start_time": "2024-11-28T09:58:22.754066Z" - } - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ba9856c5109643049718592a236b2206", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 7 files: 0%| | 0/7 [00:00` tokens in the very beginning of your text and `` token after that. For this call, we will infer how many images each text has and add special tokens.\n" - ] - } - ], - "source": [ - "from colpali_engine.models import ColPaliProcessor\n", - "\n", - "model_name = \"vidore/colpali-v1.2-merged\"\n", - "\n", - "processor = ColPaliProcessor.from_pretrained(model_name)\n", - "# Process the inputs\n", - "batch_images_onnx = processor.process_images(images)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "89c2fbe3d64964fc", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-28T10:03:56.766986Z", - "start_time": "2024-11-28T10:02:43.893495Z" - } - }, - "outputs": [], - "source": [ - "import onnxruntime as ort\n", - "\n", - "sess = ort.InferenceSession(\"/Users/d.rudenko/dev/qdrant/colpali-v1.2-merged-onnx/model.onnx\")\n", - "image_embeddings_onnx = sess.run(\n", - " [sess.get_outputs()[0].name],\n", - " {\n", - " \"input_ids\": batch_images_onnx[\"input_ids\"].numpy(),\n", - " \"pixel_values\": batch_images_onnx[\"pixel_values\"].numpy(),\n", - " \"attention_mask\": batch_images_onnx[\"attention_mask\"].numpy(),\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "61b43dd6caaa0909", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-28T10:06:23.238770Z", - "start_time": "2024-11-28T10:06:23.235457Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(1, 2, 1030, 128)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy as np\n", - "\n", - "np.array(image_embeddings_onnx).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "5be8ebb15c6dfaa6", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-28T10:59:48.765049Z", - "start_time": "2024-11-28T10:59:48.761122Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[[ 0.015 0.051 0.059 0.026 -0.061 -0.027 -0.014]\n", - " [-0.22 -0.111 0.046 0.081 -0.048 -0.052 -0.086]\n", - " [-0.184 -0.131 0.004 0.062 -0.038 -0.059 -0.127]\n", - " [-0.209 -0.113 0.015 0.059 -0.035 -0.035 -0.072]\n", - " [-0.031 -0.044 0.092 -0.005 0.006 -0.057 -0.061]\n", - " [-0.18 -0.039 0.031 0.003 0.083 -0.041 0.088]\n", - " [-0.091 0.023 0.116 -0.02 0.039 -0.064 -0.026]]\n", - "\n", - " [[-0.25 -0.112 -0.065 -0.014 0.005 -0.092 0.024]\n", - " [-0.22 -0.096 -0.014 0.039 -0.02 -0.12 -0.004]\n", - " [-0.228 -0.114 0.031 0.019 0.034 -0.052 -0.031]\n", - " [-0.274 -0.186 0.095 -0.019 0.017 0.021 -0.016]\n", - " [-0.186 -0.061 -0.01 0.065 -0.058 -0.05 0.019]\n", - " [-0.183 -0.11 -0.034 -0.042 0.026 -0.071 0.02 ]\n", - " [-0.153 -0.072 -0.015 0.088 -0.081 -0.043 0.04 ]]]\n" - ] - } - ], - "source": [ - "print(np.round(image_embeddings_onnx[0][:, :7, :7], decimals=3))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "bc9f7ffda971d3ba", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-28T10:59:02.286294Z", - "start_time": "2024-11-28T10:59:02.264997Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 0.01533 , 0.05118 , 0.05948 , 0.02583 , -0.06128 , -0.02682 ,\n", - " -0.013565, 0.10254 , -0.0983 , 0.1109 , -0.00342 , -0.0344 ,\n", - " -0.00887 , -0.1616 , 0.09814 , 0.2257 , 0.03976 , 0.03687 ,\n", - " 0.1648 , 0.06866 , 0.0396 , 0.1672 , 0.1455 , -0.1387 ,\n", - " 0.1203 , 0.04907 , -0.07965 , -0.0885 , 0.01982 , 0.0404 ,\n", - " -0.07513 , -0.02844 , 0.04337 , 0.03857 , -0.1065 , 0.0288 ,\n", - " -0.1279 , -0.1126 , 0.03363 , -0.0507 , 0.11584 , 0.0483 ,\n", - " 0.035 , -0.08417 , -0.0907 , 0.0279 , 0.1394 , -0.10364 ,\n", - " -0.1471 , -0.07135 , -0.136 , 0.1289 , 0.082 , 0.02232 ,\n", - " -0.00571 , -0.02547 , 0.1053 , 0.0377 , 0.0148 , 0.02795 ,\n", - " -0.01859 , -0.11066 , -0.12195 , 0.0583 , 0.0995 , 0.01086 ,\n", - " 0.0859 , 0.1302 , -0.10126 , 0.005417, 0.05423 , -0.1808 ,\n", - " 0.1444 , 0.1885 , 0.09247 , -0.04718 , 0.1018 , -0.02997 ,\n", - " -0.0598 , -0.011284, 0.1203 , -0.1313 , -0.04584 , -0.02725 ,\n", - " -0.1277 , -0.04236 , -0.08466 , -0.0861 , 0.1131 , 0.02806 ,\n", - " -0.0947 , 0.04388 , 0.04263 , 0.03598 , -0.06866 , -0.06018 ,\n", - " -0.02763 , -0.0972 , 0.11505 , -0.1097 , -0.04166 , 0.0742 ,\n", - " -0.06683 , -0.02188 , -0.1663 , -0.0902 , 0.02594 , -0.03802 ,\n", - " -0.034 , -0.04828 , -0.05765 , 0.0633 , -0.02515 , -0.08826 ,\n", - " -0.09753 , -0.10974 , -0.074 , -0.02083 , -0.1301 , 0.1383 ,\n", - " 0.1428 , 0.0935 , 0.0949 , 0.03876 , 0.08514 , -0.12256 ,\n", - " -0.0451 , -0.002306], dtype=float16)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.array(image_embeddings_onnx)[0][0][0]" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "34a238b20e5fcab2", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-20T15:39:41.314176Z", - "start_time": "2024-11-20T15:39:41.308579Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import numpy as np\n", - "\n", - "np.allclose(image_embeddings_onnx[0][0], fastembed_i_embeddings[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "5ca3b11eb3813a87", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-20T15:39:42.081408Z", - "start_time": "2024-11-20T15:39:42.078582Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 0.01533 , 0.05118 , 0.05948 , 0.02583 , -0.06128 , -0.02682 ,\n", - " -0.013565, 0.10254 , -0.0983 , 0.1109 , -0.00342 , -0.0344 ,\n", - " -0.00887 , -0.1616 , 0.09814 , 0.2257 , 0.03976 , 0.03687 ,\n", - " 0.1648 , 0.06866 , 0.0396 , 0.1672 , 0.1455 , -0.1387 ,\n", - " 0.1203 , 0.04907 , -0.07965 , -0.0885 , 0.01982 , 0.0404 ,\n", - " -0.07513 , -0.02844 , 0.04337 , 0.03857 , -0.1065 , 0.0288 ,\n", - " -0.1279 , -0.1126 , 0.03363 , -0.0507 , 0.11584 , 0.0483 ,\n", - " 0.035 , -0.08417 , -0.0907 , 0.0279 , 0.1394 , -0.10364 ,\n", - " -0.1471 , -0.07135 , -0.136 , 0.1289 , 0.082 , 0.02232 ,\n", - " -0.00571 , -0.02547 , 0.1053 , 0.0377 , 0.0148 , 0.02795 ,\n", - " -0.01859 , -0.11066 , -0.12195 , 0.0583 , 0.0995 , 0.01086 ,\n", - " 0.0859 , 0.1302 , -0.10126 , 0.005417, 0.05423 , -0.1808 ,\n", - " 0.1444 , 0.1885 , 0.09247 , -0.04718 , 0.1018 , -0.02997 ,\n", - " -0.0598 , -0.011284, 0.1203 , -0.1313 , -0.04584 , -0.02725 ,\n", - " -0.1277 , -0.04236 , -0.08466 , -0.0861 , 0.1131 , 0.02806 ,\n", - " -0.0947 , 0.04388 , 0.04263 , 0.03598 , -0.06866 , -0.06018 ,\n", - " -0.02763 , -0.0972 , 0.11505 , -0.1097 , -0.04166 , 0.0742 ,\n", - " -0.06683 , -0.02188 , -0.1663 , -0.0902 , 0.02594 , -0.03802 ,\n", - " -0.034 , -0.04828 , -0.05765 , 0.0633 , -0.02515 , -0.08826 ,\n", - " -0.09753 , -0.10974 , -0.074 , -0.02083 , -0.1301 , 0.1383 ,\n", - " 0.1428 , 0.0935 , 0.0949 , 0.03876 , 0.08514 , -0.12256 ,\n", - " -0.0451 , -0.002306], dtype=float16)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "image_embeddings_onnx[0][0][0]" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "2c52a4d7d83aeda7", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-20T16:05:32.115768Z", - "start_time": "2024-11-20T16:05:32.090218Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[array([[ 0.01532745, 0.05117798, 0.05947876, ..., -0.12255859,\n", - " -0.04510498, -0.00230598],\n", - " [-0.22009277, -0.11071777, 0.04562378, ..., 0.00257111,\n", - " -0.06988525, 0.12384033],\n", - " [-0.18371582, -0.13085938, 0.00393677, ..., -0.02949524,\n", - " -0.05444336, 0.1295166 ],\n", - " ...,\n", - " [-0.1418457 , 0.01023102, 0.1239624 , ..., -0.00460434,\n", - " 0.17321777, 0.09454346],\n", - " [-0.24572754, -0.06878662, 0.11834717, ..., -0.02763367,\n", - " -0.03022766, 0.08917236],\n", - " [-0.2211914 , -0.04171753, 0.19519043, ..., -0.01535797,\n", - " -0.02432251, -0.03561401]], dtype=float32),\n", - " array([[-2.49877930e-01, -1.11511230e-01, -6.51855469e-02, ...,\n", - " 3.19519043e-02, 3.44543457e-02, -1.33666992e-02],\n", - " [-2.20336914e-01, -9.56420898e-02, -1.39694214e-02, ...,\n", - " -8.88705254e-05, -1.57318115e-02, -1.00555420e-02],\n", - " [-2.28271484e-01, -1.14501953e-01, 3.10058594e-02, ...,\n", - " 7.59277344e-02, -4.28466797e-02, 1.19262695e-01],\n", - " ...,\n", - " [-2.04589844e-01, -4.86755371e-02, 8.46557617e-02, ...,\n", - " -3.98254395e-02, 1.66625977e-01, 9.71679688e-02],\n", - " [-2.88085938e-01, -4.50439453e-02, 7.69653320e-02, ...,\n", - " -4.36096191e-02, -1.28784180e-02, 6.26220703e-02],\n", - " [-2.67578125e-01, -3.25317383e-02, 1.66625977e-01, ...,\n", - " -2.90679932e-03, -1.52282715e-02, -3.62243652e-02]], dtype=float32)]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fastembed_i_embeddings" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "786bfac25eb7704a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/experiments/colpali_text_test.ipynb b/experiments/colpali_text_test.ipynb deleted file mode 100644 index b3392089..00000000 --- a/experiments/colpali_text_test.ipynb +++ /dev/null @@ -1,444 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 6, - "id": "54b3bfd4ad5b9ee6", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-28T10:44:32.841758Z", - "start_time": "2024-11-28T10:44:32.830025Z" - } - }, - "outputs": [], - "source": [ - "# Your inputs\n", - "queries = [\n", - " # \"Is attention really all you need?\",\n", - " # \"Are Benjamin, Antoine, Merve, and Jo best friends?\",\n", - " # \"Long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long\"\n", - " \"hello world\",\n", - " \"flag embedding\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "86ee1b68fb88b11d", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-27T22:54:23.016952Z", - "start_time": "2024-11-27T22:33:14.872976Z" - } - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f1b463d5ae47404f951fecc6629e8008", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 7 files: 0%| | 0/7 [00:00 Date: Fri, 29 Nov 2024 19:37:22 +0100 Subject: [PATCH 23/27] Fix dependency back + non-needed changes in tests --- pyproject.toml | 2 +- tests/test_image_onnx_embeddings.py | 16 +++------------- tests/test_text_onnx_embeddings.py | 22 +++------------------- 3 files changed, 7 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 64475d5a..39ecf6c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ onnxruntime = ">=1.17.0,<1.20.0" tqdm = "^4.66" requests = "^2.31" tokenizers = ">=0.15,<1.0" -huggingface-hub = {version = ">=0.20,<1.0", extras = ["hf_transfer"]} +huggingface-hub = ">=0.20,<1.0" loguru = "^0.7.2" numpy = [ { version = ">=1.21", python = "<3.12" }, diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 90ca0654..27d8d13b 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -21,9 +21,6 @@ "Qdrant/Unicom-ViT-B-32": np.array( [0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186] ), - "akshayballal/colpali-v1.2-merged": np.array( - [0.01533, 0.05118, 0.05948, 0.02583, -0.06128, -0.02682] - ), } @@ -49,16 +46,9 @@ def test_embedding(): canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] - if isinstance(dim, tuple): - assert embeddings.shape == (len(images), *dim) - assert np.allclose( - embeddings[0][0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] - else: - assert embeddings.shape == (len(images), dim) - assert np.allclose( - embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] + assert np.allclose( + embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), model_desc["model"] assert np.allclose(embeddings[1], embeddings[2]), model_desc["model"] diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 28ac7553..4e14d61b 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -64,15 +64,6 @@ ), "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]), "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]), - "akshayballal/colpali-v1.2-merged": [ - 0.1581, - -0.03748, - 0.09265, - -0.0002161, - 0.0762, - 0.02055, - 0.09937, - ], } @@ -92,16 +83,9 @@ def test_embedding(): canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] - if isinstance(dim, tuple): - assert embeddings.shape == (len(docs), *dim) - assert np.allclose( - embeddings[0][0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] - else: - assert embeddings.shape == (len(docs), dim) - assert np.allclose( - embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 - ), model_desc["model"] + assert np.allclose( + embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 + ), model_desc["model"] if is_ci: delete_model_cache(model.model._model_dir) From 36844627085eb8bd1a9534f67bec9dfc9835d076 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 19:39:28 +0100 Subject: [PATCH 24/27] Fix dependency back + non-needed changes in tests --- tests/test_image_onnx_embeddings.py | 2 +- tests/test_text_onnx_embeddings.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 27d8d13b..5fbf8e3f 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -43,7 +43,7 @@ def test_embedding(): ] embeddings = list(model.embed(images)) embeddings = np.stack(embeddings, axis=0) - + assert embeddings.shape == (len(images), dim) canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] assert np.allclose( diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 4e14d61b..418fb8eb 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -74,15 +74,12 @@ def test_embedding(): if not is_ci and model_desc["size_in_GB"] > 1: continue - dim = model_desc["dim"] - model = TextEmbedding(model_name=model_desc["model"]) docs = ["hello world", "flag embedding"] embeddings = list(model.embed(docs)) embeddings = np.stack(embeddings, axis=0) canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] - assert np.allclose( embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 ), model_desc["model"] From 4d856c644c3576f8f4a7288b124cf3f454c218d3 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 19:40:13 +0100 Subject: [PATCH 25/27] Fix dependency back + non-needed changes in tests --- tests/test_text_onnx_embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 418fb8eb..e8b115f1 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -73,12 +73,12 @@ def test_embedding(): for model_desc in TextEmbedding.list_supported_models(): if not is_ci and model_desc["size_in_GB"] > 1: continue - + dim = model_desc["dim"] model = TextEmbedding(model_name=model_desc["model"]) docs = ["hello world", "flag embedding"] embeddings = list(model.embed(docs)) embeddings = np.stack(embeddings, axis=0) - + assert embeddings.shape == (2, dim) canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] assert np.allclose( embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 From b6a51c0086be4b18099b59f7a76e376a4b34aee7 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Fri, 29 Nov 2024 19:40:42 +0100 Subject: [PATCH 26/27] Fix dependency back + non-needed changes in tests --- tests/test_text_onnx_embeddings.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index e8b115f1..f576330c 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -73,12 +73,15 @@ def test_embedding(): for model_desc in TextEmbedding.list_supported_models(): if not is_ci and model_desc["size_in_GB"] > 1: continue + dim = model_desc["dim"] + model = TextEmbedding(model_name=model_desc["model"]) docs = ["hello world", "flag embedding"] embeddings = list(model.embed(docs)) embeddings = np.stack(embeddings, axis=0) assert embeddings.shape == (2, dim) + canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] assert np.allclose( embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 From b9eebef3a90e0210b746738d5a00395b03f0e0d7 Mon Sep 17 00:00:00 2001 From: Dmitrii Rudenko Date: Mon, 9 Dec 2024 19:16:27 +0100 Subject: [PATCH 27/27] Docstrings for colpali --- fastembed/late_interaction/colpali.py | 129 +++++++++++++++++++++++++- 1 file changed, 126 insertions(+), 3 deletions(-) diff --git a/fastembed/late_interaction/colpali.py b/fastembed/late_interaction/colpali.py index 3f557ac8..63c5eb90 100644 --- a/fastembed/late_interaction/colpali.py +++ b/fastembed/late_interaction/colpali.py @@ -52,18 +52,37 @@ def _post_process_onnx_output( self, output: OnnxOutputContext, ) -> Iterable[np.ndarray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[np.ndarray]: Post-processed output as NumPy arrays. + """ return output.model_output.astype(np.float32) @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: - """Lists the supported models. + """ + Lists the supported models. Returns: list[dict[str, Any]]: A list of dictionaries containing the model information. """ return supported_colpali_models - def _preprocess_queries(self, documents: list[str]): + def _preprocess_queries(self, documents: list[str]) -> list[str]: + """ + Preprocess the input text queries by adding special tokens and padding. + + Args: + documents (list[str]): List of text queries. + + Returns: + list[str]: Preprocessed text queries. + """ texts_query: list[str] = [] for query in documents: @@ -76,6 +95,16 @@ def _preprocess_queries(self, documents: list[str]): def _preprocess_images_input( self, inputs: list[Union[ImageInput]], **kwargs: Any ) -> dict[str, np.ndarray]: + """ + Preprocess the input images for ONNX model inference. + + Args: + inputs (list[Union[ImageInput]]): List of image inputs. + **kwargs: Additional preprocessing arguments. + + Returns: + dict[str, np.ndarray]: Preprocessed image inputs as a dictionary. + """ with contextlib.ExitStack(): image_files = [ Image.open(image) if not isinstance(image, Image.Image) else image @@ -94,6 +123,19 @@ def embed( is_doc: bool = False, **kwargs, ) -> OnnxOutputContext: + """ + Generate embeddings for the given input, either images or text. + + Args: + inputs (Union[ImageInput, str]): Input data (images or text). + batch_size (int, optional): Batch size for embedding. Defaults to 16. + parallel (Optional[int], optional): Number of parallel threads. Defaults to None. + is_doc (bool, optional): Indicates if input is a document. Defaults to False. + **kwargs: Additional arguments for embedding. + + Yields: + OnnxOutputContext: Embedding output context. + """ if is_doc: yield from self._embed_documents( model_name=self.model_name, @@ -121,12 +163,32 @@ def embed( ) def onnx_embed(self, inputs: Union[ImageInput, str], **kwargs) -> OnnxOutputContext: + """ + Embed inputs using the ONNX model. + + Args: + inputs (Union[ImageInput, str]): Input data (images or text). + **kwargs: Additional arguments for embedding. + + Returns: + OnnxOutputContext: Embedding output context. + """ if isinstance(inputs[0], str): return self.onnx_embed_text(inputs, **kwargs) else: return self.onnx_embed_image(inputs, **kwargs) def onnx_embed_image(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: + """ + Embed images using the ONNX model. + + Args: + images (List[ImageInput]): List of image inputs. + **kwargs: Additional arguments for embedding. + + Returns: + OnnxOutputContext: Embedding output context for images. + """ with contextlib.ExitStack(): image_files = [ Image.open(image) if not isinstance(image, Image.Image) else image @@ -144,6 +206,16 @@ def onnx_embed_text( documents: List[str], **kwargs, ) -> OnnxOutputContext: + """ + Embed text using the ONNX model. + + Args: + documents (List[str]): List of text documents. + **kwargs: Additional arguments for embedding. + + Returns: + OnnxOutputContext: Embedding output context for text. + """ documents = self._preprocess_queries(documents) encoded = self.tokenize(documents, **kwargs) input_ids = np.array([self.QUERY_MARKER_TOKEN_ID + e.ids[2:] for e in encoded]) @@ -162,6 +234,16 @@ def onnx_embed_text( def _preprocess_onnx_image_input( self, onnx_input: Dict[str, np.ndarray], **kwargs ) -> Dict[str, np.ndarray]: + """ + Add placeholders for text input when processing image data for ONNX. + + Args: + onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs. + **kwargs: Additional arguments. + + Returns: + Dict[str, np.ndarray]: ONNX input with text placeholders. + """ onnx_input["input_ids"] = np.array( [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]] ) @@ -173,6 +255,16 @@ def _preprocess_onnx_image_input( def _preprocess_onnx_text_input( self, onnx_input: Dict[str, np.ndarray], **kwargs ) -> Dict[str, np.ndarray]: + """ + Add placeholders for image input when processing text data for ONNX. + + Args: + onnx_input (Dict[str, np.ndarray]): Preprocessed text inputs. + **kwargs: Additional arguments. + + Returns: + Dict[str, np.ndarray]: ONNX input with image placeholders. + """ empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32) onnx_input["pixel_values"] = np.array( [empty_image_placeholder for _ in onnx_input["input_ids"]] @@ -192,6 +284,20 @@ def __init__( device_id: Optional[int] = None, **kwargs, ): + """ + Initialize the ColPali model. + + Args: + model_name (str): Name of the model to use. + cache_dir (Optional[str], optional): Directory for caching model files. Defaults to None. + threads (Optional[int], optional): Number of threads for inference. Defaults to None. + providers (Optional[Sequence[OnnxProvider]], optional): ONNX providers for model execution. Defaults to None. + cuda (bool, optional): Whether to use CUDA for inference. Defaults to False. + device_ids (Optional[list[int]], optional): List of CUDA device IDs. Defaults to None. + lazy_load (bool, optional): Whether to lazily load the model. Defaults to False. + device_id (Optional[int], optional): Specific device ID to use. Defaults to None. + **kwargs: Additional arguments for model initialization. + """ super().__init__(model_name, cache_dir, threads, **kwargs) self.model_description = self._get_model_description(model_name) self._model_dir = self.download_model( @@ -214,7 +320,7 @@ def __init__( def load_onnx_model(self) -> None: """ - Load the onnx model. + Load the ONNX model for inference. """ self._load_onnx_model( model_dir=self._model_dir, @@ -228,11 +334,28 @@ def load_onnx_model(self) -> None: @classmethod def _get_worker_class(cls) -> Type[TextEmbeddingWorker]: + """ + Get the worker class for text/image embedding. + + Returns: + Type[TextEmbeddingWorker]: The worker class. + """ return ColPaliEmbeddingWorker class ColPaliEmbeddingWorker(TextEmbeddingWorker): def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: + """ + Initialize the ColPali embedding worker. + + Args: + model_name (str): Name of the model to use. + cache_dir (str): Directory for caching model files. + **kwargs: Additional arguments for initialization. + + Returns: + ColPali: Initialized ColPali model instance. + """ return ColPali( model_name=model_name, cache_dir=cache_dir,