From 87ce1aa65c3a6ae9da4e82c70de9ed168ba9fa1e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tim=20S=C3=BCberkr=C3=BCb?=
Date: Mon, 29 Jan 2024 15:11:00 +0100
Subject: [PATCH] Format test_embed

---
 tests/test_embed.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/tests/test_embed.py b/tests/test_embed.py
index dc4751c..34b46f5 100644
--- a/tests/test_embed.py
+++ b/tests/test_embed.py
@@ -10,7 +10,8 @@ from aleph_alpha_client.embedding import (
     BatchSemanticEmbeddingRequest,
     SemanticEmbeddingRequest,
-    SemanticRepresentation, BatchSemanticEmbeddingResponse,
+    SemanticRepresentation,
+    BatchSemanticEmbeddingResponse,
 )
 from aleph_alpha_client.prompt import Prompt
 from tests.common import (
@@ -61,7 +62,9 @@ async def test_batch_embed_semantic_with_async_client(
 ):
     words = ["car", "elephant", "kitchen sink", "rubber", "sun"]
     r = random.Random(4082)
-    prompts = list([Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)])
+    prompts = list(
+        [Prompt.from_text(words[r.randint(0, 4)]) for i in range(num_prompts)]
+    )
 
     request = BatchSemanticEmbeddingRequest(
         prompts=prompts,
@@ -69,7 +72,10 @@ async def test_batch_embed_semantic_with_async_client(
         compress_to_size=128,
     )
     result = await async_client.batch_semantic_embed(
-        request=request, num_concurrent_requests=10, batch_size=batch_size, model="luminous-base"
+        request=request,
+        num_concurrent_requests=10,
+        batch_size=batch_size,
+        model="luminous-base",
     )
 
     # We have no control over the exact tokenizer used in the backend, so we cannot know the exact
@@ -127,7 +133,11 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ
     }
     httpserver.expect_ordered_request(
         "/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
-    ).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
+    ).respond_with_json(
+        BatchSemanticEmbeddingResponse(
+            model_version="1", embeddings=[], num_tokens_prompt_total=1
+        ).to_json()
+    )
 
     async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1)
     await async_client.batch_semantic_embed(request, model=model_name)
@@ -226,6 +236,10 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer):
     expected_body = {**request.to_json(), "model": model_name}
     httpserver.expect_ordered_request(
         "/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
-    ).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
+    ).respond_with_json(
+        BatchSemanticEmbeddingResponse(
+            model_version="1", embeddings=[], num_tokens_prompt_total=1
+        ).to_json()
+    )
     sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1)
     sync_client.batch_semantic_embed(request, model=model_name)