diff --git a/README.md b/README.md index d2aa57d..4b69c4d 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ pip install aleph-alpha-client ```python -from aleph_alpha_client import ImagePrompt, AlephAlphaModel, AlephAlphaClient, CompletionRequest +from aleph_alpha_client import ImagePrompt, AlephAlphaModel, AlephAlphaClient, CompletionRequest, Prompt import os model = AlephAlphaModel( @@ -33,14 +33,14 @@ model = AlephAlphaModel( url = "https://cdn-images-1.medium.com/max/1200/1*HunNdlTmoPj8EKpl-jqvBA.png" image = ImagePrompt.from_url(url) -prompt = [ +prompt = Prompt([ image, "Q: What does the picture show? A:", -] +]) request = CompletionRequest(prompt=prompt, maximum_tokens=20) result = model.complete(request) -print(result.completions[0]["completion"]) +print(result.completions[0].completion) ``` @@ -48,7 +48,7 @@ print(result.completions[0]["completion"]) ```python -from aleph_alpha_client import AlephAlphaClient, AlephAlphaModel, EvaluationRequest +from aleph_alpha_client import AlephAlphaClient, AlephAlphaModel, EvaluationRequest, Prompt import os model = AlephAlphaModel( @@ -56,7 +56,7 @@ model = AlephAlphaModel( model_name = "luminous-extended" ) -request = EvaluationRequest(prompt="The api works", completion_expected=" well") +request = EvaluationRequest(prompt=Prompt.from_text("The api works"), completion_expected=" well") result = model.evaluate(request) print(result) @@ -69,7 +69,7 @@ print(result) ```python -from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EvaluationRequest +from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EvaluationRequest, Prompt import os model = AlephAlphaModel( @@ -80,10 +80,10 @@ model = AlephAlphaModel( url = "https://upload.wikimedia.org/wikipedia/commons/thumb/7/74/2008-09-24_Blockbuster_in_Durham.jpg/330px-2008-09-24_Blockbuster_in_Durham.jpg" image = ImagePrompt.from_url(url) -prompt = [ +prompt = Prompt([ image, "Q: What is the name of the store?\nA:", -] +]) request = EvaluationRequest(prompt=prompt, completion_expected=" Blockbuster Video") result = model.evaluate(request) @@ -96,7 +96,7 @@ print(result) ```python -from aleph_alpha_client import AlephAlphaModel, AlephAlphaClient, EmbeddingRequest +from aleph_alpha_client import AlephAlphaModel, AlephAlphaClient, EmbeddingRequest, Prompt import os model = AlephAlphaModel( @@ -104,7 +104,7 @@ model = AlephAlphaModel( model_name = "luminous-extended" ) -request = EmbeddingRequest(prompt=["This is an example."], layers=[-1], pooling=["mean"]) +request = EmbeddingRequest(prompt=Prompt.from_text("This is an example."), layers=[-1], pooling=["mean"]) result = model.embed(request) print(result) @@ -116,7 +116,7 @@ print(result) ```python -from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EmbeddingRequest +from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EmbeddingRequest, Prompt import os model = AlephAlphaModel( @@ -127,10 +127,10 @@ model = AlephAlphaModel( url = "https://upload.wikimedia.org/wikipedia/commons/thumb/7/74/2008-09-24_Blockbuster_in_Durham.jpg/330px-2008-09-24_Blockbuster_in_Durham.jpg" image = ImagePrompt.from_url(url) -prompt = [ +prompt = Prompt([ image, "Q: What is the name of the store?\nA:", -] +]) request = EmbeddingRequest(prompt=prompt, layers=[-1], pooling=["mean"]) result = model.embed(request) diff --git a/aleph_alpha_client/__init__.py b/aleph_alpha_client/__init__.py index 3c863e9..20b6f9b 100644 --- a/aleph_alpha_client/__init__.py +++ b/aleph_alpha_client/__init__.py @@ -1,6 +1,7 @@ from .aleph_alpha_client import AlephAlphaClient, QuotaError, POOLING_OPTIONS from .aleph_alpha_model import AlephAlphaModel from .image import ImagePrompt +from .prompt import Prompt from .explanation import ExplanationRequest from .embedding import EmbeddingRequest from .completion import CompletionRequest diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 1d8a446..1837335 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -421,10 +421,16 @@ def qa( return response_json def _explain(self, model: str, request: ExplanationRequest, hosting: Optional[str] = None): - body = request.render_as_body(model, hosting) + body = { + "model": model, + "prompt": [_to_prompt_item(item) for item in request.prompt.items], + "target": request.target, + "suppression_factor": request.suppression_factor, + "directional": request.directional, + "conceptual_suppression_threshold": request.conceptual_suppression_threshold + } response = requests.post(f"{self.host}explain", headers=self.request_headers, json=body) - response_dict = self._translate_errors(response) - return response_dict + return self._translate_errors(response) @staticmethod diff --git a/aleph_alpha_client/aleph_alpha_model.py b/aleph_alpha_client/aleph_alpha_model.py index fd030e1..6602720 100644 --- a/aleph_alpha_client/aleph_alpha_model.py +++ b/aleph_alpha_client/aleph_alpha_model.py @@ -1,7 +1,11 @@ -from typing import Any, Mapping +from collections import ChainMap +from typing import Any, Mapping, Union from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient from aleph_alpha_client.completion import CompletionRequest, CompletionResponse -from aleph_alpha_client.detokenization import DetokenizationRequest, DetokenizationResponse +from aleph_alpha_client.detokenization import ( + DetokenizationRequest, + DetokenizationResponse, +) from aleph_alpha_client.embedding import EmbeddingRequest, EmbeddingResponse from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse from aleph_alpha_client.explanation import ExplanationRequest @@ -10,35 +14,54 @@ class AlephAlphaModel: - - def __init__(self, client: AlephAlphaClient, model_name: str, hosting: str = "cloud") -> None: + def __init__( + self, client: AlephAlphaClient, model_name: str, hosting: str = "cloud" + ) -> None: self.client = client self.model_name = model_name self.hosting = hosting def complete(self, request: CompletionRequest) -> CompletionResponse: - response_json = self.client.complete(model = self.model_name, hosting=self.hosting, **request._asdict()) + response_json = self.client.complete( + model=self.model_name, hosting=self.hosting, **self.as_request_dict(request) + ) return CompletionResponse.from_json(response_json) def tokenize(self, request: TokenizationRequest) -> TokenizationResponse: - response_json = self.client.tokenize(model = self.model_name, **request._asdict()) + response_json = self.client.tokenize(model=self.model_name, **request._asdict()) return TokenizationResponse.from_json(response_json) def detokenize(self, request: DetokenizationRequest) -> DetokenizationResponse: - response_json = self.client.detokenize(model = self.model_name, **request._asdict()) + response_json = self.client.detokenize( + model=self.model_name, **request._asdict() + ) return DetokenizationResponse.from_json(response_json) def embed(self, request: EmbeddingRequest) -> EmbeddingResponse: - response_json = self.client.embed(model = self.model_name, hosting=self.hosting, **request._asdict()) + response_json = self.client.embed( + model=self.model_name, hosting=self.hosting, **self.as_request_dict(request) + ) return EmbeddingResponse.from_json(response_json) def evaluate(self, request: EvaluationRequest) -> EvaluationResponse: - response_json = self.client.evaluate(model = self.model_name, hosting=self.hosting, **request._asdict()) + response_json = self.client.evaluate( + model=self.model_name, hosting=self.hosting, **self.as_request_dict(request) + ) return EvaluationResponse.from_json(response_json) def qa(self, request: QaRequest) -> QaResponse: - response_json = self.client.qa(model = self.model_name, hosting=self.hosting, **request._asdict()) + response_json = self.client.qa( + model=self.model_name, hosting=self.hosting, **request._asdict() + ) return QaResponse.from_json(response_json) def _explain(self, request: ExplanationRequest) -> Mapping[str, Any]: - return self.client._explain(model = self.model_name, hosting=self.hosting, request=request) + return self.client._explain( + model=self.model_name, hosting=self.hosting, request=request + ) + + @staticmethod + def as_request_dict( + request: Union[CompletionRequest, EmbeddingRequest, EvaluationRequest] + ) -> Mapping[str, Any]: + return ChainMap({"prompt": request.prompt.items}, request._asdict()) diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py index 1e2f0d4..d7ac5e7 100644 --- a/aleph_alpha_client/completion.py +++ b/aleph_alpha_client/completion.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union from aleph_alpha_client.image import ImagePrompt -from aleph_alpha_client.prompt import _to_serializable_prompt +from aleph_alpha_client.prompt import Prompt, _to_serializable_prompt class CompletionRequest(NamedTuple): @@ -75,7 +75,7 @@ class CompletionRequest(NamedTuple): Our goal is to improve your results while using our API. But you can always pass disable_optimizations: true and we will leave your prompt and completion untouched. """ - prompt: Sequence[Union[str, ImagePrompt]] + prompt: Prompt maximum_tokens: int = 64 temperature: float = 0.0 top_k: int = 0 @@ -92,28 +92,6 @@ class CompletionRequest(NamedTuple): tokens: bool = False disable_optimizations: bool = False - def render_as_body(self, model: str, hosting: str) -> Dict[str, Any]: - return { - "model": model, - "hosting": hosting, - "prompt": _to_serializable_prompt(self.prompt), - "maximum_tokens": self.maximum_tokens, - "temperature": self.temperature, - "top_k": self.top_k, - "top_p": self.top_p, - "presence_penalty": self.presence_penalty, - "frequency_penalty": self.frequency_penalty, - "best_of": self.best_of, - "n": self.n, - "logit_bias": self.logit_bias, - "log_probs": self.log_probs, - "repetition_penalties_include_prompt": self.repetition_penalties_include_prompt, - "use_multiplicative_presence_penalty": self.use_multiplicative_presence_penalty, - "stop_sequences": self.stop_sequences, - "tokens": self.tokens, - "disable_optimizations": self.disable_optimizations, - } - class CompletionResult(NamedTuple): log_probs: Optional[Sequence[Mapping[str, Optional[float]]]] = None diff --git a/aleph_alpha_client/detokenization.py b/aleph_alpha_client/detokenization.py index 1db6ad0..382a6bb 100644 --- a/aleph_alpha_client/detokenization.py +++ b/aleph_alpha_client/detokenization.py @@ -2,14 +2,14 @@ class DetokenizationRequest(NamedTuple): + """Describes a detokenization request. + + Parameters + token_ids (Sequence[int]) + Ids of the tokens for which the text should be returned. + """ token_ids: Sequence[int] - def render_as_body(self, model: str) -> Dict[str, Any]: - return { - "model": model, - "token_ids": self.token_ids, - } - class DetokenizationResponse(NamedTuple): result: Sequence[str] diff --git a/aleph_alpha_client/embedding.py b/aleph_alpha_client/embedding.py index d018c64..7aba9f2 100644 --- a/aleph_alpha_client/embedding.py +++ b/aleph_alpha_client/embedding.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from aleph_alpha_client.image import ImagePrompt -from aleph_alpha_client.prompt import _to_prompt_item +from aleph_alpha_client.prompt import Prompt, _to_prompt_item class EmbeddingRequest(NamedTuple): @@ -33,23 +33,12 @@ class EmbeddingRequest(NamedTuple): """ - prompt: Sequence[Union[str, ImagePrompt]] + prompt: Prompt layers: List[int] pooling: List[str] type: Optional[str] = None tokens: bool = False - def render_as_body(self, model: str, hosting=Optional[str]) -> dict: - return { - "model": model, - "hosting": hosting, - "prompt": [_to_prompt_item(item) for item in self.prompt], - "layers": self.layers, - "pooling": self.pooling, - "type": self.type, - "tokens": self.tokens, - } - class EmbeddingResponse(NamedTuple): model_version: str diff --git a/aleph_alpha_client/evaluation.py b/aleph_alpha_client/evaluation.py index 9cc9152..26f7c63 100644 --- a/aleph_alpha_client/evaluation.py +++ b/aleph_alpha_client/evaluation.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union -from aleph_alpha_client.image import ImagePrompt -from aleph_alpha_client.prompt import _to_serializable_prompt +from aleph_alpha_client.prompt import Prompt class EvaluationRequest(NamedTuple): @@ -15,17 +14,9 @@ class EvaluationRequest(NamedTuple): The ground truth completion expected to be produced given the prompt. """ - prompt: Sequence[Union[str, ImagePrompt]] + prompt: Prompt completion_expected: str - def render_as_body(self, model: str, hosting=Optional[str]) -> dict: - return { - "model": model, - "hosting": hosting, - "prompt": _to_serializable_prompt(self.prompt), - "completion_expected": self.completion_expected, - } - class EvaluationResponse(NamedTuple): model_version: str diff --git a/aleph_alpha_client/explanation.py b/aleph_alpha_client/explanation.py index bdf84fe..fdf89db 100644 --- a/aleph_alpha_client/explanation.py +++ b/aleph_alpha_client/explanation.py @@ -1,22 +1,10 @@ from typing import List, NamedTuple, Optional, Union -from aleph_alpha_client.image import ImagePrompt -from aleph_alpha_client.prompt import _to_prompt_item +from aleph_alpha_client.prompt import Prompt class ExplanationRequest(NamedTuple): - prompt: List[Union[str, ImagePrompt]] + prompt: Prompt target: str directional: bool suppression_factor: float conceptual_suppression_threshold: Optional[float] = None - - - def render_as_body(self, model: str, hosting=Optional[str]) -> dict: - return { - "model": model, - "prompt": [_to_prompt_item(item) for item in self.prompt], - "target": self.target, - "suppression_factor": self.suppression_factor, - "directional": self.directional, - "conceptual_suppression_threshold": self.conceptual_suppression_threshold - } \ No newline at end of file diff --git a/aleph_alpha_client/prompt.py b/aleph_alpha_client/prompt.py index ed4c8c6..07ef621 100644 --- a/aleph_alpha_client/prompt.py +++ b/aleph_alpha_client/prompt.py @@ -10,6 +10,11 @@ class Prompt(NamedTuple): def from_text(text: str) -> "Prompt": return Prompt([text]) + @staticmethod + def from_image(image: ImagePrompt) -> "Prompt": + return Prompt([image]) + + def _to_prompt_item(item: Union[str, ImagePrompt]) -> Dict[str, str]: if isinstance(item, str): return {"type": "text", "data": item} diff --git a/aleph_alpha_client/qa.py b/aleph_alpha_client/qa.py index b7b789a..65873f4 100644 --- a/aleph_alpha_client/qa.py +++ b/aleph_alpha_client/qa.py @@ -51,23 +51,6 @@ class QaRequest(NamedTuple): max_answers: int = 0 min_score: float = 0.0 - def render_as_body(self, model: str, hosting: str): - serialized_documents = [ - document._to_serializable_document() for document in self.documents - ] - - return { - "model": model, - "hosting": hosting, - "query": self.query, - "documents": serialized_documents, - "maximum_tokens": self.maximum_tokens, - "max_answers": self.max_answers, - "min_score": self.min_score, - "max_chunk_size": self.max_chunk_size, - "disable_optimizations": self.disable_optimizations, - } - class QaAnswer(NamedTuple): answer: str diff --git a/aleph_alpha_client/tokenization.py b/aleph_alpha_client/tokenization.py index 87b6ea9..38aeed3 100644 --- a/aleph_alpha_client/tokenization.py +++ b/aleph_alpha_client/tokenization.py @@ -2,17 +2,25 @@ class TokenizationRequest(NamedTuple): + """Describes a tokenization request. + + Parameters + prompt (str) + The text prompt which should be converted into tokens + + tokens (bool) + True to extract text-tokens + + token_ids (bool) + True to extract token-ids + + Returns + TokenizationResponse + """ prompt: str tokens: bool token_ids: bool - def render_as_body(self, model: str) -> Dict[str, Any]: - return { - "model": model, - "prompt": self.prompt, - "tokens": self.tokens, - "token_ids": self.token_ids, - } class TokenizationResponse(NamedTuple): tokens: Optional[Sequence[str]] = None diff --git a/readme.ipynb b/readme.ipynb index 3228837..b8ae0b3 100644 --- a/readme.ipynb +++ b/readme.ipynb @@ -31,7 +31,7 @@ "metadata": {}, "outputs": [], "source": [ - "from aleph_alpha_client import ImagePrompt, AlephAlphaModel, AlephAlphaClient, CompletionRequest\n", + "from aleph_alpha_client import ImagePrompt, AlephAlphaModel, AlephAlphaClient, CompletionRequest, Prompt\n", "import os\n", "\n", "model = AlephAlphaModel(\n", @@ -43,14 +43,14 @@ "url = \"https://cdn-images-1.medium.com/max/1200/1*HunNdlTmoPj8EKpl-jqvBA.png\"\n", "\n", "image = ImagePrompt.from_url(url)\n", - "prompt = [\n", + "prompt = Prompt([\n", " image,\n", " \"Q: What does the picture show? A:\",\n", - "]\n", + "])\n", "request = CompletionRequest(prompt=prompt, maximum_tokens=20)\n", "result = model.complete(request)\n", "\n", - "print(result.completions[0][\"completion\"])" + "print(result.completions[0].completion)" ] }, { @@ -67,7 +67,7 @@ "metadata": {}, "outputs": [], "source": [ - "from aleph_alpha_client import AlephAlphaClient, AlephAlphaModel, EvaluationRequest\n", + "from aleph_alpha_client import AlephAlphaClient, AlephAlphaModel, EvaluationRequest, Prompt\n", "import os\n", "\n", "model = AlephAlphaModel(\n", @@ -75,7 +75,7 @@ " model_name = \"luminous-extended\"\n", ")\n", "\n", - "request = EvaluationRequest(prompt=\"The api works\", completion_expected=\" well\")\n", + "request = EvaluationRequest(prompt=Prompt.from_text(\"The api works\"), completion_expected=\" well\")\n", "result = model.evaluate(request)\n", "\n", "print(result)\n" @@ -95,7 +95,7 @@ "metadata": {}, "outputs": [], "source": [ - "from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EvaluationRequest\n", + "from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EvaluationRequest, Prompt\n", "import os\n", "\n", "model = AlephAlphaModel(\n", @@ -106,10 +106,10 @@ "\n", "url = \"https://upload.wikimedia.org/wikipedia/commons/thumb/7/74/2008-09-24_Blockbuster_in_Durham.jpg/330px-2008-09-24_Blockbuster_in_Durham.jpg\"\n", "image = ImagePrompt.from_url(url)\n", - "prompt = [\n", + "prompt = Prompt([\n", " image,\n", " \"Q: What is the name of the store?\\nA:\",\n", - "]\n", + "])\n", "request = EvaluationRequest(prompt=prompt, completion_expected=\" Blockbuster Video\")\n", "result = model.evaluate(request)\n", "\n", @@ -130,7 +130,7 @@ "metadata": {}, "outputs": [], "source": [ - "from aleph_alpha_client import AlephAlphaModel, AlephAlphaClient, EmbeddingRequest\n", + "from aleph_alpha_client import AlephAlphaModel, AlephAlphaClient, EmbeddingRequest, Prompt\n", "import os\n", "\n", "model = AlephAlphaModel(\n", @@ -138,7 +138,7 @@ " model_name = \"luminous-extended\"\n", ")\n", "\n", - "request = EmbeddingRequest(prompt=[\"This is an example.\"], layers=[-1], pooling=[\"mean\"])\n", + "request = EmbeddingRequest(prompt=Prompt.from_text(\"This is an example.\"), layers=[-1], pooling=[\"mean\"])\n", "result = model.embed(request)\n", "\n", "print(result)" @@ -158,7 +158,7 @@ "metadata": {}, "outputs": [], "source": [ - "from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EmbeddingRequest\n", + "from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EmbeddingRequest, Prompt\n", "import os\n", "\n", "model = AlephAlphaModel(\n", @@ -169,10 +169,10 @@ "\n", "url = \"https://upload.wikimedia.org/wikipedia/commons/thumb/7/74/2008-09-24_Blockbuster_in_Durham.jpg/330px-2008-09-24_Blockbuster_in_Durham.jpg\"\n", "image = ImagePrompt.from_url(url)\n", - "prompt = [\n", + "prompt = Prompt([\n", " image,\n", " \"Q: What is the name of the store?\\nA:\",\n", - "]\n", + "])\n", "request = EmbeddingRequest(prompt=prompt, layers=[-1], pooling=[\"mean\"])\n", "result = model.embed(request)\n", "\n", diff --git a/tests/test_complete.py b/tests/test_complete.py index cc5d1f3..140ef62 100644 --- a/tests/test_complete.py +++ b/tests/test_complete.py @@ -4,13 +4,14 @@ from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel from aleph_alpha_client.completion import CompletionRequest +from aleph_alpha_client.prompt import Prompt from tests.common import client, model_name, model def test_complete(model: AlephAlphaModel): request = CompletionRequest( - prompt="", + prompt=Prompt.from_text(""), maximum_tokens=7, tokens=False, log_probs=0, @@ -38,7 +39,7 @@ def test_complete_fails(model: AlephAlphaModel): # when posting an illegal request request = CompletionRequest( - prompt="", + prompt=Prompt.from_text(""), maximum_tokens=-1, tokens=False, log_probs=0, diff --git a/tests/test_embed.py b/tests/test_embed.py index b4090a2..d03cc3d 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -2,13 +2,14 @@ import pytest from aleph_alpha_client import AlephAlphaClient, EmbeddingRequest from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel +from aleph_alpha_client.prompt import Prompt from tests.common import client, model_name, model def test_embed(model: AlephAlphaModel): request = EmbeddingRequest( - prompt=["hello"], layers=[0, -1], pooling=["mean", "max"] + prompt=Prompt.from_text("hello"), layers=[0, -1], pooling=["mean", "max"] ) result = model.embed(request=request) @@ -33,9 +34,7 @@ def test_embed_with_client(client: AlephAlphaClient, model_name: str): def test_embedding_of_one_token_aggregates_identically(model: AlephAlphaModel): request = EmbeddingRequest( - prompt=[ - "hello" - ], # it is important for this test that we only embed one single token + prompt=Prompt.from_text("hello"), # it is important for this test that we only embed one single token layers=[0, -1], pooling=["mean", "max"], ) @@ -49,7 +48,7 @@ def test_embedding_of_one_token_aggregates_identically(model: AlephAlphaModel): def test_embed_with_tokens(model: AlephAlphaModel): request = EmbeddingRequest( - prompt=["abc"], layers=[-1], pooling=["mean"], tokens=True + prompt=Prompt.from_text("abc"), layers=[-1], pooling=["mean"], tokens=True ) result = model.embed(request) @@ -64,7 +63,7 @@ def test_failing_embedding_request(model: AlephAlphaModel): assert model.model_name in (model["name"] for model in model.client.available_models()) # when posting an illegal request - request = EmbeddingRequest(prompt=["abc"], layers=[0, 1, 2], pooling=["mean"]) + request = EmbeddingRequest(prompt=Prompt.from_text("abc"), layers=[0, 1, 2], pooling=["mean"]) # then we expect an exception tue to a bad request response from the API with pytest.raises(ValueError) as e: diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 7a6648a..6aacbd6 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -4,12 +4,13 @@ from aleph_alpha_client import AlephAlphaClient from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel from aleph_alpha_client.evaluation import EvaluationRequest +from aleph_alpha_client.prompt import Prompt from tests.common import client, model_name, model def test_evaluate(model: AlephAlphaModel): - request = EvaluationRequest(prompt=["hello"], completion_expected="world") + request = EvaluationRequest(prompt=Prompt.from_text("hello"), completion_expected="world") result = model.evaluate(request) @@ -30,7 +31,7 @@ def test_evaluate_fails(model: AlephAlphaModel): # when posting an illegal request request = EvaluationRequest( - prompt=["hello"], + prompt=Prompt.from_text("hello"), completion_expected="", ) diff --git a/tests/test_explanation.py b/tests/test_explanation.py index 30ae62a..d40fb52 100644 --- a/tests/test_explanation.py +++ b/tests/test_explanation.py @@ -1,6 +1,7 @@ import pytest -from aleph_alpha_client import AlephAlphaClient, ExplanationRequest +from aleph_alpha_client import ExplanationRequest from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel +from aleph_alpha_client.prompt import Prompt from tests.common import client, model_name, model @@ -8,7 +9,7 @@ def test_explanation(model: AlephAlphaModel): request = ExplanationRequest( - prompt=["An apple a day"], + prompt=Prompt.from_text("An apple a day"), target=" keeps the doctor away", directional=False, suppression_factor=0.1, @@ -26,7 +27,7 @@ def test_explain_fails(model: AlephAlphaModel): # when posting an illegal request request = ExplanationRequest( - prompt=["An apple a day"], + prompt=Prompt.from_text("An apple a day"), target=" keeps the doctor away", directional=False, suppression_factor=0.1,