Skip to content

Commit

Permalink
nit: Remove cache dir from tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Dec 23, 2024
1 parent 197b381 commit c917201
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/test_text_multitask_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_batch_embedding():
if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name, cache_dir="models")
model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} default task")

Expand Down Expand Up @@ -105,7 +105,7 @@ def test_single_embedding():
if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name, cache_dir="models")
model = TextEmbedding(model_name=model_name)

for task in CANONICAL_VECTOR_VALUES[model_name]:
print(f"evaluating {model_name} task_id: {task['task_id']}")
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_single_embedding_query():
if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name, cache_dir="models")
model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} query_embed task_id: {task_id}")

Expand Down Expand Up @@ -170,7 +170,7 @@ def test_single_embedding_passage():
if model_name not in CANONICAL_VECTOR_VALUES.keys():
continue

model = TextEmbedding(model_name=model_name, cache_dir="models")
model = TextEmbedding(model_name=model_name)

print(f"evaluating {model_name} passage_embed task_id: {task_id}")

Expand All @@ -196,7 +196,7 @@ def test_parallel_processing():
model_name = "jinaai/jina-embeddings-v3"
dim = 1024

model = TextEmbedding(model_name=model_name, cache_dir="models")
model = TextEmbedding(model_name=model_name)

embeddings = list(model.embed(docs, batch_size=10, parallel=2))
embeddings = np.stack(embeddings, axis=0)
Expand All @@ -221,7 +221,7 @@ def test_parallel_processing():
)
def test_lazy_load(model_name):
is_ci = os.getenv("CI")
model = TextEmbedding(model_name=model_name, lazy_load=True, cache_dir="models")
model = TextEmbedding(model_name=model_name, lazy_load=True)
assert not hasattr(model.model, "model")

list(model.embed(docs))
Expand Down

0 comments on commit c917201

Please sign in to comment.