diff --git a/platform_reports/config.py b/platform_reports/config.py index 8bbbfb96..33c85828 100644 --- a/platform_reports/config.py +++ b/platform_reports/config.py @@ -29,6 +29,7 @@ class KubeConfig: client_key_path: str | None = None token: str | None = None token_path: str | None = None + token_update_interval_s: int = 300 conn_timeout_s: int = 300 read_timeout_s: int = 100 conn_pool_size: int = 100 diff --git a/platform_reports/kube_client.py b/platform_reports/kube_client.py index 29623be7..5e504cd6 100644 --- a/platform_reports/kube_client.py +++ b/platform_reports/kube_client.py @@ -1,9 +1,11 @@ from __future__ import annotations +import asyncio import enum import logging import ssl from collections.abc import Sequence +from contextlib import suppress from dataclasses import dataclass, field from datetime import datetime from math import ceil @@ -154,6 +156,7 @@ def __init__( self._token = config.token self._trace_configs = trace_configs self._client: aiohttp.ClientSession | None = None + self._token_updater_task: asyncio.Task[None] | None = None def _create_ssl_context(self) -> ssl.SSLContext | None: if self._config.url.scheme != "https": @@ -170,7 +173,7 @@ def _create_ssl_context(self) -> ssl.SSLContext | None: return ssl_context async def __aenter__(self) -> KubeClient: - self._client = await self._create_http_client() + await self._init() return self async def __aexit__( @@ -181,40 +184,45 @@ async def __aexit__( ) -> None: await self.aclose() - async def _create_http_client(self) -> aiohttp.ClientSession: + async def _init(self) -> None: connector = aiohttp.TCPConnector( limit=self._config.conn_pool_size, ssl=self._create_ssl_context() ) - if self._config.auth_type == KubeClientAuthType.TOKEN: - token = self._token - if not token: - assert self._config.token_path is not None - token = Path(self._config.token_path).read_text() - headers = {"Authorization": "Bearer " + token} - else: - headers = {} + if self._config.token_path: + self._token = Path(self._config.token_path).read_text() + self._token_updater_task = asyncio.create_task(self._start_token_updater()) timeout = aiohttp.ClientTimeout( connect=self._config.conn_timeout_s, total=self._config.read_timeout_s ) - return aiohttp.ClientSession( + self._client = aiohttp.ClientSession( connector=connector, timeout=timeout, - headers=headers, trace_configs=self._trace_configs, ) - async def aclose(self) -> None: - assert self._client - await self._client.close() - - async def _reload_http_client(self) -> None: - await self.aclose() - self._token = None - self._client = await self._create_http_client() + async def _start_token_updater(self) -> None: + if not self._config.token_path: + return + while True: + try: + token = Path(self._config.token_path).read_text() + if token != self._token: + self._token = token + logger.info("Kube token was refreshed") + except asyncio.CancelledError: + raise + except Exception as exc: + logger.exception("Failed to update kube token: %s", exc) + await asyncio.sleep(self._config.token_update_interval_s) - async def init_if_needed(self) -> None: - if not self._client or self._client.closed: - self._client = await self._create_http_client() + async def aclose(self) -> None: + if self._client: + await self._client.close() + if self._token_updater_task: + self._token_updater_task.cancel() + with suppress(asyncio.CancelledError): + await self._token_updater_task + self._token_updater_task = None def _get_pods_url(self, namespace: str) -> URL: if namespace: @@ -223,8 +231,6 @@ def _get_pods_url(self, namespace: str) -> URL: @trace async def get_node(self, name: str) -> Node: - await self.init_if_needed() - assert self._client url = self._config.url / "api/v1/nodes" / name payload = await self._request(method="get", url=url) assert payload["kind"] == "Node" @@ -234,8 +240,6 @@ async def get_node(self, name: str) -> Node: async def get_pods( self, namespace: str = "", field_selector: str = "", label_selector: str = "" ) -> Sequence[Pod]: - await self.init_if_needed() - assert self._client params: dict[str, str] = {} if field_selector: params["fieldSelector"] = field_selector @@ -247,22 +251,19 @@ async def get_pods( assert payload["kind"] == "PodList" return [Pod.from_payload(i) for i in payload["items"]] + def _create_headers(self, headers: dict[str, Any] | None = None) -> dict[str, Any]: + headers = dict(headers) if headers else {} + if self._config.auth_type == KubeClientAuthType.TOKEN and self._token: + headers["Authorization"] = "Bearer " + self._token + return headers + async def _request(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - await self.init_if_needed() + headers = self._create_headers(kwargs.pop("headers", None)) assert self._client, "client is not initialized" - doing_retry = kwargs.pop("doing_retry", False) - async with self._client.request(*args, **kwargs) as resp: + async with self._client.request(*args, headers=headers, **kwargs) as resp: payload = await resp.json() - try: self._raise_for_status(payload) - except KubeClientUnauthorized: - if doing_retry: - raise - # K8s SA's token might be stale, need to refresh it and retry - await self._reload_http_client() - kwargs["doing_retry"] = True - payload = await self._request(*args, **kwargs) - return payload + return payload def _raise_for_status(self, payload: dict[str, Any]) -> None: kind = payload["kind"] diff --git a/setup.cfg b/setup.cfg index 7c28d29d..7bc9f84d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,7 +17,7 @@ platforms = any install_requires = neuro-auth-client==22.6.1 neuro-config-client==23.3.0 - neuro-sdk==23.7.0 + neuro-sdk==23.2.0 neuro-logging==21.12.2 aiohttp==3.8.4 python-jose==3.3.0 diff --git a/tests/integration/test_kube_client.py b/tests/integration/test_kube_client.py index b29d751b..5d2ea677 100644 --- a/tests/integration/test_kube_client.py +++ b/tests/integration/test_kube_client.py @@ -1,7 +1,84 @@ +from __future__ import annotations + +import asyncio +import os +import tempfile +from collections.abc import AsyncIterator, Iterator +from pathlib import Path +from typing import Any + +import aiohttp +import aiohttp.web import pytest +from yarl import URL +from platform_reports.config import KubeClientAuthType, KubeConfig from platform_reports.kube_client import KubeClient, KubeClientError, Node, PodPhase +from .conftest import create_local_app_server + + +class TestKubeClientTokenUpdater: + @pytest.fixture + async def kube_app(self) -> aiohttp.web.Application: + async def _get_pods(request: aiohttp.web.Request) -> aiohttp.web.Response: + auth = request.headers["Authorization"] + token = auth.split()[-1] + app["token"]["value"] = token + return aiohttp.web.json_response({"kind": "PodList", "items": []}) + + app = aiohttp.web.Application() + app["token"] = {"value": ""} + app.router.add_routes( + [aiohttp.web.get("/api/v1/namespaces/default/pods", _get_pods)] + ) + return app + + @pytest.fixture + async def kube_server( + self, kube_app: aiohttp.web.Application, unused_tcp_port_factory: Any + ) -> AsyncIterator[str]: + async with create_local_app_server( + kube_app, port=unused_tcp_port_factory() + ) as address: + yield f"http://{address.host}:{address.port}" + + @pytest.fixture + def kube_token_path(self) -> Iterator[str]: + _, path = tempfile.mkstemp() + Path(path).write_text("token-1") + yield path + os.remove(path) + + @pytest.fixture + async def kube_client( + self, kube_server: str, kube_token_path: str + ) -> AsyncIterator[KubeClient]: + async with KubeClient( + config=KubeConfig( + url=URL(kube_server), + auth_type=KubeClientAuthType.TOKEN, + token_path=kube_token_path, + token_update_interval_s=1, + ) + ) as client: + yield client + + async def test_token_periodically_updated( + self, + kube_app: aiohttp.web.Application, + kube_client: KubeClient, + kube_token_path: str, + ) -> None: + await kube_client.get_pods("default") + assert kube_app["token"]["value"] == "token-1" + + Path(kube_token_path).write_text("token-2") + await asyncio.sleep(2) + + await kube_client.get_pods("default") + assert kube_app["token"]["value"] == "token-2" + class TestKubeClient: async def test_get_node(self, kube_client: KubeClient, kube_node: Node) -> None: