WIP: Added jina clip text embedding
hh-space-invader committed Nov 19, 2024
1 parent 1343e55 commit d5b1f70
Showing 2 changed files with 75 additions and 52 deletions.
11 changes: 11 additions & 0 deletions fastembed/text/clip_embedding.py
@@ -18,6 +18,17 @@
},
"model_file": "model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/text_model.onnx",
},
]
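
For reference, a minimal usage sketch of the newly registered model through the existing fastembed API (the model name, import path, and call signatures are taken from this diff; the snippet itself is illustrative and not part of the commit, and assumes the WIP support works once merged):

from fastembed.text.text_embedding import TextEmbedding

# Select the newly added CLIP text model by its registry name.
model = TextEmbedding(model_name="jinaai/jina-clip-v1")
docs = ["hello world", "flag embedding"]
# embed() yields one 768-dimensional vector per document, per the "dim" entry above.
embeddings = list(model.embed(docs))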


116 changes: 64 additions & 52 deletions tests/test_text_onnx_embeddings.py
@@ -1,7 +1,6 @@
import os

import numpy as np
import pytest

from fastembed.text.text_embedding import TextEmbedding
from tests.utils import delete_model_cache
@@ -62,19 +61,32 @@
),
"snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
}


def test_embedding():
is_ci = os.getenv("CI")

for model_desc in TextEmbedding.list_supported_models():
for model_desc in [
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/text_model.onnx",
}
]:
if not is_ci and model_desc["size_in_GB"] > 1:
continue

dim = model_desc["dim"]

model = TextEmbedding(model_name=model_desc["model"])
model = TextEmbedding(model_name=model_desc["model"], cache_dir="models")
docs = ["hello world", "flag embedding"]
embeddings = list(model.embed(docs))
embeddings = np.stack(embeddings, axis=0)
@@ -88,66 +100,66 @@ def test_embedding():
delete_model_cache(model.model._model_dir)
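
As a hedged illustration of how the canonical values above are typically consumed, the assertion sits in the elided middle of test_embedding and is not shown in this diff; assuming the dictionary of reference vectors is named CANONICAL_VECTOR_VALUES and the comparison uses a small absolute tolerance, the check would look roughly like:

# Hypothetical sketch: dictionary name and tolerance are assumptions, not shown in the diff.
canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]]
# Compare only the leading dimensions of the first embedding against the stored reference slice.
assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)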


@pytest.mark.parametrize(
"n_dims,model_name",
[(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
)
def test_batch_embedding(n_dims, model_name):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name)
# @pytest.mark.parametrize(
# "n_dims,model_name",
# [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
# )
# def test_batch_embedding(n_dims, model_name):
# is_ci = os.getenv("CI")
# model = TextEmbedding(model_name=model_name)

docs = ["hello world", "flag embedding"] * 100
embeddings = list(model.embed(docs, batch_size=10))
embeddings = np.stack(embeddings, axis=0)
# docs = ["hello world", "flag embedding"] * 100
# embeddings = list(model.embed(docs, batch_size=10))
# embeddings = np.stack(embeddings, axis=0)

assert embeddings.shape == (200, n_dims)
if is_ci:
delete_model_cache(model.model._model_dir)
# assert embeddings.shape == (200, n_dims)
# if is_ci:
# delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"n_dims,model_name",
[(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
)
def test_parallel_processing(n_dims, model_name):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name)
# @pytest.mark.parametrize(
# "n_dims,model_name",
# [(384, "BAAI/bge-small-en-v1.5"), (768, "jinaai/jina-embeddings-v2-base-en")],
# )
# def test_parallel_processing(n_dims, model_name):
# is_ci = os.getenv("CI")
# model = TextEmbedding(model_name=model_name)

docs = ["hello world", "flag embedding"] * 100
embeddings = list(model.embed(docs, batch_size=10, parallel=2))
embeddings = np.stack(embeddings, axis=0)
# docs = ["hello world", "flag embedding"] * 100
# embeddings = list(model.embed(docs, batch_size=10, parallel=2))
# embeddings = np.stack(embeddings, axis=0)

embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
embeddings_2 = np.stack(embeddings_2, axis=0)
# embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
# embeddings_2 = np.stack(embeddings_2, axis=0)

embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
embeddings_3 = np.stack(embeddings_3, axis=0)
# embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
# embeddings_3 = np.stack(embeddings_3, axis=0)

assert embeddings.shape == (200, n_dims)
assert np.allclose(embeddings, embeddings_2, atol=1e-3)
assert np.allclose(embeddings, embeddings_3, atol=1e-3)
# assert embeddings.shape == (200, n_dims)
# assert np.allclose(embeddings, embeddings_2, atol=1e-3)
# assert np.allclose(embeddings, embeddings_3, atol=1e-3)

if is_ci:
delete_model_cache(model.model._model_dir)
# if is_ci:
# delete_model_cache(model.model._model_dir)


@pytest.mark.parametrize(
"model_name",
["BAAI/bge-small-en-v1.5"],
)
def test_lazy_load(model_name):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")
docs = ["hello world", "flag embedding"]
list(model.embed(docs))
assert hasattr(model.model, "model")
# @pytest.mark.parametrize(
# "model_name",
# ["BAAI/bge-small-en-v1.5"],
# )
# def test_lazy_load(model_name):
# is_ci = os.getenv("CI")
# model = TextEmbedding(model_name=model_name, lazy_load=True)
# assert not hasattr(model.model, "model")
# docs = ["hello world", "flag embedding"]
# list(model.embed(docs))
# assert hasattr(model.model, "model")

model = TextEmbedding(model_name=model_name, lazy_load=True)
list(model.query_embed(docs))
# model = TextEmbedding(model_name=model_name, lazy_load=True)
# list(model.query_embed(docs))

model = TextEmbedding(model_name=model_name, lazy_load=True)
list(model.passage_embed(docs))
# model = TextEmbedding(model_name=model_name, lazy_load=True)
# list(model.passage_embed(docs))

if is_ci:
delete_model_cache(model.model._model_dir)
# if is_ci:
# delete_model_cache(model.model._model_dir)
