From 538b77b0982bb88aebbac38e301c5f8d7e92639f Mon Sep 17 00:00:00 2001 From: Kevin Cameron Grismore Date: Thu, 1 Feb 2024 22:02:21 -0600 Subject: [PATCH] add client caching --- prefect_kubernetes/worker.py | 85 +++++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 21 deletions(-) diff --git a/prefect_kubernetes/worker.py b/prefect_kubernetes/worker.py index 5499c37..f078362 100644 --- a/prefect_kubernetes/worker.py +++ b/prefect_kubernetes/worker.py @@ -99,6 +99,7 @@ For more information about work pools and workers, checkout out the [Prefect docs](https://docs.prefect.io/concepts/work-pools/). """ + import asyncio import base64 import enum @@ -109,6 +110,8 @@ import time from contextlib import contextmanager from datetime import datetime +from functools import lru_cache +from threading import Lock from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple import anyio.abc @@ -134,6 +137,7 @@ BaseWorkerResult, ) from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import BaseModel if PYDANTIC_VERSION.startswith("2."): from pydantic.v1 import Field, validator @@ -148,6 +152,7 @@ _slugify_label_value, _slugify_name, enable_socket_keep_alive, + hash_collection, ) if TYPE_CHECKING: @@ -161,6 +166,59 @@ else: kubernetes = lazy_import("kubernetes") +_LOCK = Lock() + + +class HashableKubernetesClusterConfig(BaseModel): + """ + A hashable version of the KubernetesClusterConfig class. + Used for caching. + """ + + config: Optional[dict[str, Any]] = Field(...) + context_name: str = Field(...) + + def __hash__(self): + """Make the conifg hashable.""" + return hash( + ( + hash_collection(self.config), + self.context_name, + ) + ) + + +@lru_cache(maxsize=8, typed=True) +def _get_cached_kubernetes_client( + cluster_config: Optional[HashableKubernetesClusterConfig] = None, +) -> Any: + "Returns a new Kubernetes client is there is not one cached" + with _LOCK: + # if a hard-coded cluster config is provided, use it + if cluster_config: + client = kubernetes.config.new_client_from_config_dict( + config_dict=cluster_config.config, + context=cluster_config.context_name, + ) + else: + # If no hard-coded config specified, try to load Kubernetes configuration + # within a cluster. If that doesn't work, try to load the configuration + # from the local environment, allowing any further ConfigExceptions to + # bubble up. + try: + kubernetes.config.load_incluster_config() + config = kubernetes.client.Configuration.get_default_copy() + client = kubernetes.client.ApiClient(configuration=config) + except kubernetes.config.ConfigException: + client = kubernetes.config.new_client_from_config() + + if os.environ.get( + "PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE" + ).strip().lower() in ("true", "1"): + enable_socket_keep_alive(client) + + return client + def _get_default_job_manifest_template() -> Dict[str, Any]: """Returns the default job manifest template used by the Kubernetes worker.""" @@ -688,30 +746,15 @@ def _get_configured_kubernetes_client( Returns a configured Kubernetes client. """ - # if a hard-coded cluster config is provided, use it + cluster_config = None + if configuration.cluster_config: - client = kubernetes.config.new_client_from_config_dict( - config_dict=configuration.cluster_config.config, - context=configuration.cluster_config.context_name, + cluster_config = HashableKubernetesClusterConfig( + config=configuration.cluster_config.config, + context_name=configuration.cluster_config.context_name, ) - else: - # If no hard-coded config specified, try to load Kubernetes configuration - # within a cluster. If that doesn't work, try to load the configuration - # from the local environment, allowing any further ConfigExceptions to - # bubble up. - try: - kubernetes.config.load_incluster_config() - config = kubernetes.client.Configuration.get_default_copy() - client = kubernetes.client.ApiClient(configuration=config) - except kubernetes.config.ConfigException: - client = kubernetes.config.new_client_from_config() - if os.environ.get( - "PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE" - ).strip().lower() in ("true", "1"): - enable_socket_keep_alive(client) - - return client + return _get_cached_kubernetes_client(cluster_config) def _replace_api_key_with_secret( self, configuration: KubernetesWorkerJobConfiguration, client: "ApiClient"