Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(draft) colpali WIP #394

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0583425
colpali WIP. Works for texts, does not work for images
I8dNLo Nov 8, 2024
f63f77c
WIP
I8dNLo Nov 19, 2024
c4bd2c0
Image part ColPali ✅
I8dNLo Nov 19, 2024
ee62c69
WIP Tokenizer part ColPali
I8dNLo Nov 20, 2024
3c36c28
Merge branch 'main' into colpali
I8dNLo Nov 20, 2024
34f557c
Done: Tokenizer part ColPali
I8dNLo Nov 20, 2024
9274c7d
Done: Tests
I8dNLo Nov 20, 2024
d4f4e5a
Remove unnecessary changes
I8dNLo Nov 20, 2024
7ca807e
Description changes
I8dNLo Nov 21, 2024
317ccec
Refactoring of magic numbers and values
I8dNLo Nov 21, 2024
d581de9
Refactoring to late interaction class
I8dNLo Nov 27, 2024
e43f680
Refactoring to late interaction class
I8dNLo Nov 27, 2024
423bb28
fix: Minor fix related to image.image
hh-space-invader Nov 27, 2024
dcae3ab
Moved colpali to late_interaction
I8dNLo Nov 28, 2024
62a065e
Removed colpali from text/image
I8dNLo Nov 28, 2024
68ce437
Merge remote-tracking branch 'myfork/colpali' into colpali
I8dNLo Nov 28, 2024
c040120
Tests draft
I8dNLo Nov 28, 2024
367178b
Tests draft
I8dNLo Nov 29, 2024
1fff39b
Tests draft
I8dNLo Nov 29, 2024
667eee1
Tests another draft
I8dNLo Nov 29, 2024
a9bddf3
Reduce batch size for test
I8dNLo Nov 29, 2024
e747c34
CI pre-clean up
I8dNLo Nov 29, 2024
0ff8f49
Clean_up was a lie
I8dNLo Nov 29, 2024
274e0d7
Clean up non-needed notebooks
I8dNLo Nov 29, 2024
054873a
Fix dependency back + non-needed changes in tests
I8dNLo Nov 29, 2024
3684462
Fix dependency back + non-needed changes in tests
I8dNLo Nov 29, 2024
4d856c6
Fix dependency back + non-needed changes in tests
I8dNLo Nov 29, 2024
b6a51c0
Fix dependency back + non-needed changes in tests
I8dNLo Nov 29, 2024
b9eebef
Docstrings for colpali
I8dNLo Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions fastembed/image/colpali_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import contextlib
from typing import Any, Dict, Iterable, List

import numpy as np
from PIL import Image

from fastembed.common import ImageInput
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.image.onnx_embedding import OnnxImageEmbedding

# Metadata for the ONNX checkpoints this class can load; exposed through
# ColpaliImageModel.list_supported_models().
supported_onnx_models = [
    {
        "model": "akshayballal/colpali-v1.2-merged",
        # Late-interaction output shape per image: 1030 token vectors of
        # 128 dims each (1024 image patches + fixed text-prompt suffix).
        "dim": (1030, 128),
        "description": "Image embeddings, Unimodal (image), Aligned to text latent space via PaliGemma-3B, 512 patches max, 2024.",
        "license": "mit",
        "size_in_GB": 6.08,
        "sources": {
            # Hugging Face repo that hosts the exported ONNX weights.
            "hf": "akshayballal/colpali-v1.2-merged-onnx",
        },
        # External tensor data file that must live next to model.onnx.
        "additional_files": ["model.onnx_data"],
        "model_file": "model.onnx",
    }
]


class ColpaliImageModel(OnnxImageEmbedding):
    """ONNX image-side encoder for ColPali late-interaction embeddings.

    The exported ColPali graph is multimodal, so the image path still has to
    supply text inputs: a fixed placeholder prompt is substituted for the
    tokenized text before inference.
    """

    # Token ids fed in place of real text: 1024 image-placeholder tokens
    # (id 257152) followed by the fixed prompt suffix baked into the export.
    # NOTE(review): ids copied from the exported processor — confirm against
    # the upstream ColPali processor config.
    EMPTY_TEXT_PLACEHOLDER = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108])
    # Attend to all 1030 positions (1024 image tokens + 6 suffix tokens).
    EVEN_ATTENTION_MASK = np.array([1] * 1030)

    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, np.ndarray], **kwargs
    ) -> Dict[str, np.ndarray]:
        """Replace tokenized text inputs with the fixed image-side prompt.

        Args:
            onnx_input: Mutable mapping of ONNX input names to arrays; its
                "input_ids" entry is only used to determine the batch size.

        Returns:
            The same mapping with "input_ids" and "attention_mask" replaced
            by one placeholder row per batch element.
        """
        batch_size = len(onnx_input["input_ids"])
        onnx_input["input_ids"] = np.array([self.EMPTY_TEXT_PLACEHOLDER] * batch_size)
        onnx_input["attention_mask"] = np.array([self.EVEN_ATTENTION_MASK] * batch_size)
        return onnx_input

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """
        Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_onnx_models

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
        # Cast to float32 so callers get a consistent dtype regardless of
        # what the ONNX runtime emitted.
        return output.model_output.astype(np.float32)

    def onnx_embed(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext:
        """Embed a batch of images.

        Any file handles opened here (for path-like inputs) are registered on
        the ExitStack and closed once preprocessing is done; images passed in
        as PIL objects are left to the caller to manage.
        """
        with contextlib.ExitStack() as stack:
            image_files = [
                image
                if isinstance(image, Image.Image)
                else stack.enter_context(Image.open(image))
                for image in images
            ]
            encoded = self.processor(image_files)
        onnx_input = self._build_onnx_input(encoded)
        onnx_input = self._preprocess_onnx_input(onnx_input)

        model_output = self.model.run(None, onnx_input)
        # Restore the per-image (tokens, dims) late-interaction shape.
        embeddings = model_output[0].reshape(len(images), *supported_onnx_models[0]["dim"])
        return OnnxOutputContext(model_output=embeddings)
3 changes: 2 additions & 1 deletion fastembed/image/image_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from fastembed.common import ImageInput, OnnxProvider
from fastembed.image.image_embedding_base import ImageEmbeddingBase
from fastembed.image.onnx_embedding import OnnxImageEmbedding
from fastembed.image.colpali_model import ColpaliImageModel


class ImageEmbedding(ImageEmbeddingBase):
EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding]
EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding, ColpaliImageModel]

I8dNLo marked this conversation as resolved.
Show resolved Hide resolved
@classmethod
def list_supported_models(cls) -> list[dict[str, Any]]:
Expand Down
4 changes: 2 additions & 2 deletions fastembed/image/transform/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]):
@staticmethod
def _get_resize(transforms: list[Transform], config: dict[str, Any]):
mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode == "CLIPImageProcessor":
if mode == "CLIPImageProcessor" or mode == "SiglipImageProcessor":
if config.get("do_resize", False):
size = config["size"]
if "shortest_edge" in size:
Expand Down Expand Up @@ -161,7 +161,7 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
@staticmethod
def _get_center_crop(transforms: list[Transform], config: dict[str, Any]):
mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode == "CLIPImageProcessor":
if mode == "CLIPImageProcessor" or mode == "SiglipImageProcessor":
if config.get("do_center_crop", False):
crop_size = config["crop_size"]
if isinstance(crop_size, int):
Expand Down
86 changes: 86 additions & 0 deletions fastembed/text/colpali_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Any, Dict, Iterable, List

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.text.onnx_embedding import OnnxTextEmbedding

# Metadata for the ONNX checkpoints this class can load; exposed through
# ColpaliTextModel.list_supported_models().
supported_onnx_models = [
    {
        # Same merged checkpoint as the image side: the export does not
        # separate the text tower, so the full multimodal model is
        # downloaded for text-only use as well.
        "model": "akshayballal/colpali-v1.2-merged",
        # Late-interaction output shape per query: 16 token vectors of
        # 128 dims each.
        "dim": (16, 128),
        "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.",
        "license": "mit",
        "size_in_GB": 6.08,
        "sources": {
            "hf": "akshayballal/colpali-v1.2-merged-onnx",
        },
        # Extra artifacts needed besides model.onnx: the external tensor
        # data file plus tokenizer/config files.
        "additional_files": [
            "model.onnx_data",
            "tokenizer.json",
            "tokenizer_config.json",
            "config.json",
        ],
        "model_file": "model.onnx",
    }
]


class ColpaliTextModel(OnnxTextEmbedding):
    """ONNX text-side (query) encoder for ColPali late-interaction embeddings.

    The exported ColPali graph is multimodal, so the text path still has to
    supply image inputs: an all-zero image tensor is attached to every batch.
    """

    query_prefix = "Query: "
    bos_token = "<s>"
    pad_token = "<pad>"

    # Number of pad tokens appended to every query (ColPali query template).
    QUERY_PAD_COUNT = 10
    # Token ids forced at the start of every encoded query, replacing the
    # first two ids the tokenizer produced.
    # NOTE(review): values mirror the exported tokenizer — confirm upstream.
    QUERY_PREFIX_IDS = [2, 9413]
    # Generous truncation bound; the effective limit is set by the model.
    MAX_QUERY_LENGTH = 10000
    # Shape (C, H, W) of the zero image fed on the text-only path.
    EMPTY_IMAGE_SHAPE = (3, 448, 448)

    def _preprocess_onnx_input(
        self, onnx_input: Dict[str, np.ndarray], **kwargs
    ) -> Dict[str, np.ndarray]:
        """Attach the dummy image inputs required by the multimodal graph.

        The attention mask written here is a minimal placeholder; onnx_embed
        overwrites it with the real tokenizer mask afterwards.
        """
        empty_image_placeholder = np.zeros(self.EMPTY_IMAGE_SHAPE, dtype=np.float32)
        onnx_input["pixel_values"] = np.array(
            [empty_image_placeholder for _ in onnx_input["input_ids"]]
        )
        onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]])
        return onnx_input

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Any]]:
        """
        Lists the supported models.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return supported_onnx_models

    def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
        # Cast to float32 so callers get a consistent dtype regardless of
        # what the ONNX runtime emitted.
        return output.model_output.astype(np.float32)

    def _preprocess_queries(self, documents: List[str]) -> List[str]:
        """Wrap each raw query in the ColPali query template:
        <s>Query: {text}<pad>*N followed by a newline."""
        return [
            self.bos_token
            + self.query_prefix
            + query
            + self.pad_token * self.QUERY_PAD_COUNT
            + "\n"
            for query in documents
        ]

    def onnx_embed(
        self,
        documents: List[str],
        **kwargs,
    ) -> OnnxOutputContext:
        """Tokenize templated queries and run them through the ONNX model."""
        documents = self._preprocess_queries(documents)
        self.tokenizer.enable_truncation(max_length=self.MAX_QUERY_LENGTH)
        encoded = self.tokenize(documents, **kwargs)
        # Force the fixed query-prefix ids over the tokenizer's first two ids.
        input_ids = np.array(
            [self.QUERY_PREFIX_IDS + list(e.ids[2:]) for e in encoded], dtype=np.int64
        )

        attention_mask = np.array([e.attention_mask for e in encoded])
        onnx_input = {"input_ids": input_ids}
        onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
        # Replace the placeholder mask set by _preprocess_onnx_input with
        # the real tokenizer mask.
        onnx_input["attention_mask"] = attention_mask
        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
        return OnnxOutputContext(
            model_output=model_output[0],
            attention_mask=onnx_input.get("attention_mask", attention_mask),
            input_ids=onnx_input.get("input_ids", input_ids),
        )
1 change: 0 additions & 1 deletion fastembed/text/onnx_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def onnx_embed(
onnx_input["token_type_ids"] = np.array(
[np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
)

onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)

model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)
Expand Down
17 changes: 13 additions & 4 deletions tests/test_image_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
"Qdrant/Unicom-ViT-B-32": np.array(
[0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186]
),
"akshayballal/colpali-v1.2-merged": np.array(
[0.01533, 0.05118, 0.05948, 0.02583, -0.06128, -0.02682]
),
}


Expand All @@ -43,13 +46,19 @@ def test_embedding():
]
embeddings = list(model.embed(images))
embeddings = np.stack(embeddings, axis=0)
assert embeddings.shape == (len(images), dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]]

assert np.allclose(
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc["model"]
if isinstance(dim, tuple):
assert embeddings.shape == (len(images), *dim)
assert np.allclose(
embeddings[0][0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc["model"]
else:
assert embeddings.shape == (len(images), dim)
assert np.allclose(
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc["model"]

assert np.allclose(embeddings[1], embeddings[2]), model_desc["model"]

Expand Down
24 changes: 20 additions & 4 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@
),
"snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"akshayballal/colpali-v1.2-merged": [
0.1581,
-0.03748,
0.09265,
-0.0002161,
0.0762,
0.02055,
0.09937,
],
}


Expand All @@ -80,12 +89,19 @@ def test_embedding():
docs = ["hello world", "flag embedding"]
embeddings = list(model.embed(docs))
embeddings = np.stack(embeddings, axis=0)
assert embeddings.shape == (2, dim)

canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]]
assert np.allclose(
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc["model"]

if isinstance(dim, tuple):
assert embeddings.shape == (len(docs), *dim)
assert np.allclose(
embeddings[0][0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc["model"]
else:
assert embeddings.shape == (len(docs), dim)
assert np.allclose(
embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3
), model_desc["model"]
if is_ci:
delete_model_cache(model.model._model_dir)

Expand Down
Loading