From c9172015298247342bd9c2f545756d9e9aedd00a Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Mon, 23 Dec 2024 09:51:53 +0200 Subject: [PATCH] nit: Remove cache dir from tests --- tests/test_text_multitask_embeddings.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_text_multitask_embeddings.py b/tests/test_text_multitask_embeddings.py index 606b4138..67964f62 100644 --- a/tests/test_text_multitask_embeddings.py +++ b/tests/test_text_multitask_embeddings.py @@ -74,7 +74,7 @@ def test_batch_embedding(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} default task") @@ -105,7 +105,7 @@ def test_single_embedding(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) for task in CANONICAL_VECTOR_VALUES[model_name]: print(f"evaluating {model_name} task_id: {task['task_id']}") @@ -138,7 +138,7 @@ def test_single_embedding_query(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} query_embed task_id: {task_id}") @@ -170,7 +170,7 @@ def test_single_embedding_passage(): if model_name not in CANONICAL_VECTOR_VALUES.keys(): continue - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) print(f"evaluating {model_name} passage_embed task_id: {task_id}") @@ -196,7 +196,7 @@ def test_parallel_processing(): model_name = "jinaai/jina-embeddings-v3" dim = 1024 - model = TextEmbedding(model_name=model_name, cache_dir="models") + model = TextEmbedding(model_name=model_name) embeddings = list(model.embed(docs, batch_size=10, parallel=2)) embeddings = np.stack(embeddings, axis=0) @@ -221,7 +221,7 @@ def test_parallel_processing(): ) def test_lazy_load(model_name): is_ci = os.getenv("CI") - model = TextEmbedding(model_name=model_name, lazy_load=True, cache_dir="models") + model = TextEmbedding(model_name=model_name, lazy_load=True) assert not hasattr(model.model, "model") list(model.embed(docs))