Skip to content

Commit

Permalink
Format test_embed
Browse files Browse the repository at this point in the history
  • Loading branch information
timsueberkrueb committed Jan 29, 2024
1 parent 65688cd commit 87ce1aa
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from aleph_alpha_client.embedding import (
BatchSemanticEmbeddingRequest,
SemanticEmbeddingRequest,
SemanticRepresentation, BatchSemanticEmbeddingResponse,
SemanticRepresentation,
BatchSemanticEmbeddingResponse,
)
from aleph_alpha_client.prompt import Prompt
from tests.common import (
Expand Down Expand Up @@ -61,15 +62,20 @@ async def test_batch_embed_semantic_with_async_client(
):
words = ["car", "elephant", "kitchen sink", "rubber", "sun"]
r = random.Random(4082)
prompts = list([Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)])
prompts = list(
[Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)]
)

request = BatchSemanticEmbeddingRequest(
prompts=prompts,
representation=SemanticRepresentation.Symmetric,
compress_to_size=128,
)
result = await async_client.batch_semantic_embed(
request=request, num_concurrent_requests=10, batch_size=batch_size, model="luminous-base"
request=request,
num_concurrent_requests=10,
batch_size=batch_size,
model="luminous-base",
)

# We have no control over the exact tokenizer used in the backend, so we cannot know the exact
Expand Down Expand Up @@ -127,7 +133,11 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ
}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
).respond_with_json(
BatchSemanticEmbeddingResponse(
model_version="1", embeddings=[], num_tokens_prompt_total=1
).to_json()
)
async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1)
await async_client.batch_semantic_embed(request, model=model_name)

Expand Down Expand Up @@ -226,6 +236,10 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer):
expected_body = {**request.to_json(), "model": model_name}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
).respond_with_json(
BatchSemanticEmbeddingResponse(
model_version="1", embeddings=[], num_tokens_prompt_total=1
).to_json()
)
sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1)
sync_client.batch_semantic_embed(request, model=model_name)

0 comments on commit 87ce1aa

Please sign in to comment.