Skip to content

Commit

Permalink
fix: Turn off parallel=0 in tests for models > 2 gb
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Nov 5, 2024
1 parent 52a94d7 commit c37970d
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions tests/test_multi_task_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def test_embedding():
if isinstance(dim, list): # if matryoshka
dim = dim[0]

model = MultiTaskTextEmbedding(model_name=model_name, cache_dir="models")
model = MultiTaskTextEmbedding(model_name=model_name)

for task in model_desc["tasks"]:
if task in CANONICAL_VECTOR_VALUES[model_name].keys():
Expand Down Expand Up @@ -437,7 +437,7 @@ def test_batch_embedding():
if isinstance(dim, list): # if matryoshka
dim = dim[0]

model = MultiTaskTextEmbedding(model_name=model_name, cache_dir="models")
model = MultiTaskTextEmbedding(model_name=model_name)

for task in model_desc["tasks"]:
print(f"evaluating {model_name} task: {task}")
Expand All @@ -463,7 +463,7 @@ def test_matryoshka_embeddings():
is_ci = os.getenv("CI")
embeddings_size = 64

model = MultiTaskTextEmbedding(model_name="jinaai/jina-embeddings-v3", cache_dir="models")
model = MultiTaskTextEmbedding(model_name="jinaai/jina-embeddings-v3")

with pytest.raises(ValueError):
embeddings = list(model.task_embed(docs, "text-matching", embeddings_size=100))
Expand All @@ -480,11 +480,12 @@ def test_matryoshka_embeddings():

def test_parallel_processing():
is_ci = os.getenv("CI")
model = MultiTaskTextEmbedding(model_name="jinaai/jina-embeddings-v3")
model_name = "jinaai/jina-embeddings-v3"
model = MultiTaskTextEmbedding(model_name=model_name)

token_dim = 1024
task_type = "text-matching"
task_tolerance = 1e-3
task_tolerance = 1e-6
docs = ["Hello World", "Follow the white rabbit."] * 100

embeddings = list(model.task_embed(docs, task_type=task_type, batch_size=10, parallel=2))
Expand All @@ -495,13 +496,16 @@ def test_parallel_processing():
embedding_arrays = [e.embedding for e in embeddings_2]
embeddings_2 = np.stack(embedding_arrays, axis=0)

embeddings_3 = list(model.task_embed(docs, task_type=task_type, batch_size=10, parallel=0))
embedding_arrays = [e.embedding for e in embeddings_3]
embeddings_3 = np.stack(embedding_arrays, axis=0)

assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == token_dim
assert np.allclose(embeddings, embeddings_2, atol=task_tolerance)
assert np.allclose(embeddings, embeddings_3, atol=task_tolerance)

if (
is_ci and not model._get_model_description(model_name)["size_in_GB"] > 2
): # might be too big that the ci can handle
embeddings_3 = list(model.task_embed(docs, task_type=task_type, batch_size=10, parallel=0))
embedding_arrays = [e.embedding for e in embeddings_3]
embeddings_3 = np.stack(embedding_arrays, axis=0)
assert np.allclose(embeddings, embeddings_3, atol=task_tolerance)

if is_ci:
shutil.rmtree(model.model._model_dir)
Expand All @@ -512,7 +516,7 @@ def test_parallel_processing():
["jinaai/jina-embeddings-v3"],
)
def test_lazy_load(model_name):
model = MultiTaskTextEmbedding(model_name=model_name, cache_dir="models", lazy_load=True)
model = MultiTaskTextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")

list(model.task_embed(docs, task_type="text-matching"))
Expand Down

0 comments on commit c37970d

Please sign in to comment.