diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 753afff..106d00a 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -943,7 +943,7 @@ def _build_json_body( json_body["hosting"] = self.hosting return json_body - def models(self) -> Mapping[str, Any]: + def models(self) -> List[Mapping[str, Any]]: """ Queries all models which are currently available. @@ -1357,7 +1357,9 @@ async def _get_request_text(self, endpoint: str) -> str: _raise_for_status(response.status, await response.text()) return await response.text() - async def _get_request_json(self, endpoint: str) -> Mapping[str, Any]: + async def _get_request_json( + self, endpoint: str + ) -> Union[List[Mapping[str, Any]], Mapping[str, Any]]: async with self.session.get( self.host + endpoint, ) as response: @@ -1401,13 +1403,13 @@ def _build_json_body( json_body["hosting"] = self.hosting return json_body - async def models(self) -> Mapping[str, Any]: + async def models(self) -> List[Mapping[str, Any]]: """ Queries all models which are currently available. For documentation of the response, see https://docs.aleph-alpha.com/api/available-models/ """ - return await self._get_request_json("models_available") + return await self._get_request_json("models_available") # type: ignore async def complete( self, diff --git a/tests/test_clients.py b/tests/test_clients.py index b613731..3327510 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -66,7 +66,7 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer): @pytest.mark.system_test def test_available_models_sync_client(sync_client: Client, model_name: str): models = sync_client.models() - assert model_name in [model["name"] for model in models] + assert model_name in {model["name"] for model in models} @pytest.mark.system_test