diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index b951b34..4dbbdca 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -434,7 +434,7 @@ def batch_semantic_embed(
         responses: List[EmbeddingVector] = []
         model_version = ""
-        num_prompt_tokens_total = 0
+        num_tokens_prompt_total = 0
         # The API currently only supports batch semantic embedding requests with up to 100
         # prompts per batch. As a convenience for users, this function chunks larger requests.
         for batch_request in _generate_semantic_embedding_batches(request):
@@ -446,10 +446,10 @@
             response = BatchSemanticEmbeddingResponse.from_json(raw_response)
             model_version = response.model_version
             responses.extend(response.embeddings)
-            num_prompt_tokens_total += response.num_prompt_tokens_total
+            num_tokens_prompt_total += response.num_tokens_prompt_total
 
         return BatchSemanticEmbeddingResponse(
-            model_version=model_version, embeddings=responses, num_prompt_tokens_total=num_prompt_tokens_total
+            model_version=model_version, embeddings=responses, num_tokens_prompt_total=num_tokens_prompt_total
         )
 
     def evaluate(
@@ -973,15 +973,15 @@ async def batch_semantic_embed(
             _generate_semantic_embedding_batches(request, batch_size),
             progress_bar,
         )
-        num_prompt_tokens_total = 0
+        num_tokens_prompt_total = 0
         for result in results:
             resp = BatchSemanticEmbeddingResponse.from_json(result)
             model_version = resp.model_version
             responses.extend(resp.embeddings)
-            num_prompt_tokens_total += resp.num_prompt_tokens_total
+            num_tokens_prompt_total += resp.num_tokens_prompt_total
 
         return BatchSemanticEmbeddingResponse(
-            model_version=model_version, embeddings=responses, num_prompt_tokens_total=num_prompt_tokens_total
+            model_version=model_version, embeddings=responses, num_tokens_prompt_total=num_tokens_prompt_total
         )
 
     async def evaluate(
diff --git a/aleph_alpha_client/embedding.py b/aleph_alpha_client/embedding.py
index a7a9301..29abaeb 100644
--- a/aleph_alpha_client/embedding.py
+++ b/aleph_alpha_client/embedding.py
@@ -88,7 +88,7 @@ def _asdict(self) -> Mapping[str, Any]:
 @dataclass(frozen=True)
 class EmbeddingResponse:
     model_version: str
-    num_prompt_tokens_total: int
+    num_tokens_prompt_total: int
     embeddings: Optional[Dict[Tuple[str, str], List[float]]]
     tokens: Optional[List[str]]
     message: Optional[str] = None
@@ -104,7 +104,7 @@ def from_json(json: Dict[str, Any]) -> "EmbeddingResponse":
             },
             tokens=json.get("tokens"),
             message=json.get("message"),
-            num_prompt_tokens_total=json.get("num_prompt_tokens_total", 0)
+            num_tokens_prompt_total=json.get("num_tokens_prompt_total", 0)
         )
 
 
@@ -291,7 +291,7 @@ class SemanticEmbeddingResponse:
 
     model_version: str
     embedding: EmbeddingVector
-    num_prompt_tokens_total: int
+    num_tokens_prompt_total: int
     message: Optional[str] = None
 
     @staticmethod
@@ -300,7 +300,7 @@ def from_json(json: Dict[str, Any]) -> "SemanticEmbeddingResponse":
             model_version=json["model_version"],
             embedding=json["embedding"],
             message=json.get("message"),
-            num_prompt_tokens_total=json.get("num_prompt_tokens_total", 0)
+            num_tokens_prompt_total=json.get("num_tokens_prompt_total", 0)
         )
 
 
@@ -318,18 +318,18 @@ class BatchSemanticEmbeddingResponse:
 
     model_version: str
     embeddings: Sequence[EmbeddingVector]
-    num_prompt_tokens_total: int
+    num_tokens_prompt_total: int
 
     @staticmethod
     def from_json(json: Dict[str, Any]) -> "BatchSemanticEmbeddingResponse":
         return BatchSemanticEmbeddingResponse(
-            model_version=json["model_version"], embeddings=json["embeddings"], num_prompt_tokens_total=json.get("num_prompt_tokens_total", 0)
+            model_version=json["model_version"], embeddings=json["embeddings"], num_tokens_prompt_total=json.get("num_tokens_prompt_total", 0)
         )
 
     @staticmethod
     def _from_model_version_and_embeddings(
-        model_version: str, embeddings: Sequence[EmbeddingVector], num_prompt_tokens_total: int
+        model_version: str, embeddings: Sequence[EmbeddingVector], num_tokens_prompt_total: int
     ) -> "BatchSemanticEmbeddingResponse":
         return BatchSemanticEmbeddingResponse(
-            model_version=model_version, embeddings=embeddings, num_prompt_tokens_total=num_prompt_tokens_total
+            model_version=model_version, embeddings=embeddings, num_tokens_prompt_total=num_tokens_prompt_total
         )
diff --git a/tests/test_embed.py b/tests/test_embed.py
index 4e1523c..3a84e3a 100644
--- a/tests/test_embed.py
+++ b/tests/test_embed.py
@@ -34,7 +34,7 @@ async def test_can_embed_with_async_client(async_client: AsyncClient, model_name
         request.pooling
     ) * len(request.layers)
     assert response.tokens is not None
-    assert response.num_prompt_tokens_total == 1
+    assert response.num_tokens_prompt_total == 2
 
 
 @pytest.mark.system_test
@@ -51,7 +51,7 @@ async def test_can_semantic_embed_with_async_client(
     assert response.model_version is not None
     assert response.embedding
     assert len(response.embedding) == 128
-    assert response.num_prompt_tokens_total == 1
+    assert response.num_tokens_prompt_total == 1
 
 
 @pytest.mark.parametrize("num_prompts", [1, 100, 101])
@@ -73,7 +73,7 @@ async def test_batch_embed_semantic_with_async_client(
         request=request, num_concurrent_requests=10, batch_size=batch_size
     )
     num_tokens = sum([len((await t).tokens) for t in tokens])
-    assert result.num_prompt_tokens_total == num_tokens
+    assert result.num_tokens_prompt_total == num_tokens
     assert len(result.embeddings) == num_prompts
 
     # To make sure that the ordering of responses is preserved,
@@ -147,7 +147,7 @@ def test_embed(sync_client: Client, model_name: str):
         request.layers
     )
     assert result.tokens is None
-    assert result.num_prompt_tokens_total == 1
+    assert result.num_tokens_prompt_total == 1
 
 
 @pytest.mark.system_test
@@ -184,7 +184,7 @@ def test_embed_with_tokens(sync_client: Client, model_name: str):
         request.layers
     )
     assert result.tokens is not None
-    assert result.num_prompt_tokens_total == 1
+    assert result.num_tokens_prompt_total == 2
 
 
 @pytest.mark.system_test
@@ -200,7 +200,7 @@ def test_embed_semantic(sync_client: Client):
     assert result.model_version is not None
    assert result.embedding
     assert len(result.embedding) == 128
-    assert result.num_prompt_tokens_total == 1
+    assert result.num_tokens_prompt_total == 1
 
 
 @pytest.mark.parametrize("num_prompts", [1, 100, 101, 200, 1000])
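
Note: a minimal usage sketch of the renamed field, assuming the public entry points exported by this package (Client, Prompt, SemanticRepresentation, BatchSemanticEmbeddingRequest); the token and prompt texts below are placeholders and not part of this change.

    from aleph_alpha_client import (
        Client,
        Prompt,
        SemanticRepresentation,
        BatchSemanticEmbeddingRequest,
    )

    client = Client(token="AA_TOKEN")  # placeholder API token
    request = BatchSemanticEmbeddingRequest(
        prompts=[Prompt.from_text("hello"), Prompt.from_text("world")],
        representation=SemanticRepresentation.Symmetric,
    )
    # Requests with more than 100 prompts are chunked into batches by the client.
    response = client.batch_semantic_embed(request=request)

    # Field renamed by this change: num_prompt_tokens_total -> num_tokens_prompt_total
    print(len(response.embeddings), response.num_tokens_prompt_total)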