Skip to content

Commit

Permalink
Add explain request to model
Browse files Browse the repository at this point in the history
  • Loading branch information
volkerstampa committed Jun 27, 2022
1 parent 2003c8c commit ccc00a5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
3 changes: 1 addition & 2 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,7 @@ def _explain(
response = requests.post(
f"{self.host}explain", headers=self.request_headers, json=body
)
response_dict = self._translate_errors(response)
return response_dict
return self._translate_errors(response)

@staticmethod
def _translate_errors(response):
Expand Down
5 changes: 5 additions & 0 deletions aleph_alpha_client/aleph_alpha_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Mapping
from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient
from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
from aleph_alpha_client.detokenization import DetokenizationRequest, DetokenizationResponse
from aleph_alpha_client.embedding import EmbeddingRequest, EmbeddingResponse
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.explanation import ExplanationRequest
from aleph_alpha_client.qa import QaRequest, QaResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse

Expand Down Expand Up @@ -37,3 +39,6 @@ def evaluate(self, request: EvaluationRequest) -> EvaluationResponse:
def qa(self, request: QaRequest) -> QaResponse:
response_json = self.client.qa(model = self.model_name, hosting=self.hosting, **request._asdict())
return QaResponse.from_json(response_json)

def _explain(self, request: ExplanationRequest) -> Mapping[str, Any]:
return self.client._explain(model = self.model_name, hosting=self.hosting, request=request)
17 changes: 7 additions & 10 deletions tests/test_explanation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from aleph_alpha_client import AlephAlphaClient, ExplanationRequest
from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel

from tests.common import client, model_name
from tests.common import client, model_name, model


def test_explanation(client: AlephAlphaClient, model_name: str):
def test_explanation(model: AlephAlphaModel):

request = ExplanationRequest(
prompt=["An apple a day"],
Expand All @@ -13,15 +14,15 @@ def test_explanation(client: AlephAlphaClient, model_name: str):
suppression_factor=0.1,
)

explanation = client._explain(model=model_name, request=request, hosting=None)
explanation = model._explain(request)

# List is true if not None and not empty
assert explanation["result"]


def test_explain_fails(client: AlephAlphaClient, model_name: str):
def test_explain_fails(model: AlephAlphaModel):
# given a client
assert model_name in map(lambda model: model["name"], client.available_models())
assert model.model_name in map(lambda model: model["name"], model.client.available_models())

# when posting an illegal request
request = ExplanationRequest(
Expand All @@ -34,10 +35,6 @@ def test_explain_fails(client: AlephAlphaClient, model_name: str):

# then we expect an exception tue to a bad request response from the API
with pytest.raises(ValueError) as e:
response = client._explain(
model_name,
hosting="cloud",
request=request,
)
response = model._explain(request)

assert e.value.args[0] == 400

0 comments on commit ccc00a5

Please sign in to comment.