Skip to content

Commit

Permalink
Method to check version match
Browse files Browse the repository at this point in the history
  • Loading branch information
WieslerTNG committed Jan 11, 2024
1 parent db972b8 commit c2b6b2a
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
20 changes: 20 additions & 0 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import warnings

from packaging import version
from tokenizers import Tokenizer # type: ignore
from types import TracebackType
from typing import (
Expand Down Expand Up @@ -48,6 +50,7 @@
SemanticEmbeddingRequest,
SemanticEmbeddingResponse,
)
from aleph_alpha_client.version import MIN_API_VERSION

POOLING_OPTIONS = ["mean", "max", "last_token", "abs_max"]
RETRY_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})
Expand Down Expand Up @@ -80,6 +83,16 @@ def _raise_for_status(status_code: int, text: str):
raise RuntimeError(status_code, text)


def _check_api_version(version_str: str):
api_ver = version.parse(MIN_API_VERSION)
ver = version.parse(version_str)
valid = api_ver.major == ver.major and api_ver <= ver
if not valid:
raise RuntimeError(
f"The aleph alpha client requires at least api version {api_ver}, found version {ver}"
)


AnyRequest = Union[
CompletionRequest,
EmbeddingRequest,
Expand Down Expand Up @@ -179,6 +192,10 @@ def __init__(
self.session.mount("https://", adapter)
self.session.mount("http://", adapter)

def validate_version(self) -> None:
"""Gets version of the AlephAlpha HTTP API."""
_check_api_version(self.get_version())

def get_version(self) -> str:
"""Gets version of the AlephAlpha HTTP API."""
return self._get_request("version").text
Expand Down Expand Up @@ -687,6 +704,9 @@ async def __aexit__(
):
await self.session.__aexit__(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

async def validate_version(self) -> None:
_check_api_version(await self.get_version())

async def get_version(self) -> str:
"""Gets version of the AlephAlpha HTTP API."""
return await self._get_request_text("version")
Expand Down
1 change: 1 addition & 0 deletions aleph_alpha_client/version.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__version__ = "5.0.0"
MIN_API_VERSION = "1.15.0"
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def version():
"Pillow >= 9.2.0",
"tqdm >= v4.62.0",
"python-liquid >= 1.9.4",
"packaging >= 23.2"
],
tests_require=tests_require,
extras_require={
Expand Down
28 changes: 28 additions & 0 deletions tests/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pytest_httpserver import HTTPServer
import os
import pytest

from aleph_alpha_client.version import MIN_API_VERSION
from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.completion import (
CompletionRequest,
Expand All @@ -11,6 +13,32 @@
from tests.common import model_name, sync_client, async_client


def test_api_version_mismatch_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("0.0.0")

with pytest.raises(RuntimeError):
Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()


async def test_api_version_mismatch_async_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("0.0.0")

with pytest.raises(RuntimeError):
async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
await client.validate_version()


def test_api_version_correct_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()


async def test_api_version_correct_async_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
await client.validate_version()


@pytest.mark.system_test
async def test_can_use_async_client_without_context_manager(model_name: str):
request = CompletionRequest(
Expand Down

0 comments on commit c2b6b2a

Please sign in to comment.