This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Release sockets without caching #131

Merged (3 commits) on Mar 21, 2024
prefect_kubernetes/worker.py: 170 changes (63 additions, 107 deletions)
@@ -110,8 +110,6 @@
 import time
 from contextlib import contextmanager
 from datetime import datetime
-from functools import lru_cache
-from threading import Lock
 from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union

 import anyio.abc
@@ -127,7 +125,6 @@
 from prefect.server.schemas.responses import DeploymentResponse
 from prefect.utilities.asyncutils import run_sync_in_worker_thread
 from prefect.utilities.dockerutils import get_prefect_image_name
-from prefect.utilities.hashing import hash_objects
 from prefect.utilities.importtools import lazy_import
 from prefect.utilities.pydantic import JsonPatch
 from prefect.utilities.templating import find_placeholders
@@ -138,7 +135,6 @@
     BaseWorkerResult,
 )
 from pydantic import VERSION as PYDANTIC_VERSION
-from pydantic import BaseModel

 if PYDANTIC_VERSION.startswith("2."):
     from pydantic.v1 import Field, validator
@@ -173,64 +169,6 @@
 RETRY_MAX_DELAY_JITTER_SECONDS = 3


-_LOCK = Lock()
-
-
-class HashableKubernetesClusterConfig(BaseModel):
-    """
-    A hashable version of the KubernetesClusterConfig class.
-    Used for caching.
-    """
-
-    config: dict = Field(
-        default=..., description="The entire contents of a kubectl config file."
-    )
-    context_name: str = Field(
-        default=..., description="The name of the kubectl context to use."
-    )
-
-    def __hash__(self):
-        """Make the config hashable."""
-        return hash(
-            (
-                hash_objects(self.config),
-                self.context_name,
-            )
-        )
-
-
-@lru_cache(maxsize=8, typed=True)
-def _get_configured_kubernetes_client_cached(
-    cluster_config: Optional[HashableKubernetesClusterConfig] = None,
-) -> Any:
-    """Returns a configured Kubernetes client."""
-    with _LOCK:
-        # if a hard-coded cluster config is provided, use it
-        if cluster_config:
-            client = kubernetes.config.new_client_from_config_dict(
-                config_dict=cluster_config.config,
-                context=cluster_config.context_name,
-            )
-        else:
-            # If no hard-coded config specified, try to load Kubernetes configuration
-            # within a cluster. If that doesn't work, try to load the configuration
-            # from the local environment, allowing any further ConfigExceptions to
-            # bubble up.
-            try:
-                kubernetes.config.load_incluster_config()
-                config = kubernetes.client.Configuration.get_default_copy()
-                client = kubernetes.client.ApiClient(configuration=config)
-            except kubernetes.config.ConfigException:
-                client = kubernetes.config.new_client_from_config()
-
-        if os.environ.get(
-            "PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE"
-        ).strip().lower() in ("true", "1"):
-            enable_socket_keep_alive(client)
-
-        return client
-
-

 def _get_default_job_manifest_template() -> Dict[str, Any]:
     """Returns the default job manifest template used by the Kubernetes worker."""
     return {
@@ -710,62 +648,80 @@ def _stop_job(
         grace_seconds: int = 30,
     ):
         """Removes the given Job from the Kubernetes cluster"""
-        client = self._get_configured_kubernetes_client(configuration)
-        job_cluster_uid, job_namespace, job_name = self._parse_infrastructure_pid(
-            infrastructure_pid
-        )
-
-        if job_namespace != configuration.namespace:
-            raise InfrastructureNotAvailable(
-                f"Unable to kill job {job_name!r}: The job is running in namespace "
-                f"{job_namespace!r} but this worker expected jobs to be running in "
-                f"namespace {configuration.namespace!r} based on the work pool and "
-                "deployment configuration."
-            )
-
-        current_cluster_uid = self._get_cluster_uid(client)
-        if job_cluster_uid != current_cluster_uid:
-            raise InfrastructureNotAvailable(
-                f"Unable to kill job {job_name!r}: The job is running on another "
-                "cluster than the one specified by the infrastructure PID."
-            )
-
-        with self._get_batch_client(client) as batch_client:
-            try:
-                batch_client.delete_namespaced_job(
-                    name=job_name,
-                    namespace=job_namespace,
-                    grace_period_seconds=grace_seconds,
-                    # Foreground propagation deletes dependent objects before deleting
-                    # owner objects. This ensures that the pods are cleaned up before
-                    # the job is marked as deleted.
-                    # See: https://kubernetes.io/docs/concepts/architecture/garbage-collection/#foreground-deletion # noqa
-                    propagation_policy="Foreground",
-                )
-            except kubernetes.client.exceptions.ApiException as exc:
-                if exc.status == 404:
-                    raise InfrastructureNotFound(
-                        f"Unable to kill job {job_name!r}: The job was not found."
-                    ) from exc
-                else:
-                    raise
+        with self._get_configured_kubernetes_client(configuration) as client:
+            job_cluster_uid, job_namespace, job_name = self._parse_infrastructure_pid(
+                infrastructure_pid
+            )
+
+            if job_namespace != configuration.namespace:
+                raise InfrastructureNotAvailable(
+                    f"Unable to kill job {job_name!r}: The job is running in namespace "
+                    f"{job_namespace!r} but this worker expected jobs to be running in "
+                    f"namespace {configuration.namespace!r} based on the work pool and "
+                    "deployment configuration."
+                )
+
+            current_cluster_uid = self._get_cluster_uid(client)
+            if job_cluster_uid != current_cluster_uid:
+                raise InfrastructureNotAvailable(
+                    f"Unable to kill job {job_name!r}: The job is running on another "
+                    "cluster than the one specified by the infrastructure PID."
+                )
+
+            with self._get_batch_client(client) as batch_client:
+                try:
+                    batch_client.delete_namespaced_job(
+                        name=job_name,
+                        namespace=job_namespace,
+                        grace_period_seconds=grace_seconds,
+                        # Foreground propagation deletes dependent objects before deleting # noqa
+                        # owner objects. This ensures that the pods are cleaned up before # noqa
+                        # the job is marked as deleted.
+                        # See: https://kubernetes.io/docs/concepts/architecture/garbage-collection/#foreground-deletion # noqa
+                        propagation_policy="Foreground",
+                    )
+                except kubernetes.client.exceptions.ApiException as exc:
+                    if exc.status == 404:
+                        raise InfrastructureNotFound(
+                            f"Unable to kill job {job_name!r}: The job was not found."
+                        ) from exc
+                    else:
+                        raise

+    @contextmanager
Contributor Author:

This makes _get_configured_kubernetes_client work like all the other k8s client getters, where we call client.rest_client.pool_manager.clear() on exit, which releases sockets.
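For context, the "other k8s client getters" mentioned above (for example _get_batch_client, which _stop_job uses) follow this shape: a context manager yields an API client and clears the underlying urllib3 pool manager on exit so idle sockets are released. Below is a minimal standalone sketch of that pattern, assuming the getter wraps a shared ApiClient; it is illustrative only, not the repository's exact code.

```python
from contextlib import contextmanager
from typing import Generator

import kubernetes


@contextmanager
def get_batch_client(
    client: kubernetes.client.ApiClient,
) -> Generator[kubernetes.client.BatchV1Api, None, None]:
    """Yield a BatchV1Api bound to ``client`` and release its sockets on exit."""
    try:
        # Expose the batch API surface on top of the shared ApiClient.
        yield kubernetes.client.BatchV1Api(api_client=client)
    finally:
        # Clearing the urllib3 PoolManager closes pooled connections,
        # which is what actually releases the sockets.
        client.rest_client.pool_manager.clear()
```

Callers then use it the same way _stop_job does in the diff above: `with get_batch_client(client) as batch_client: ...`.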

     def _get_configured_kubernetes_client(
         self, configuration: KubernetesWorkerJobConfiguration
-    ) -> "ApiClient":
+    ) -> Generator["ApiClient", None, None]:
         """
         Returns a configured Kubernetes client.
         """

-        cluster_config = None
-
-        if configuration.cluster_config:
-            cluster_config = HashableKubernetesClusterConfig(
-                config=configuration.cluster_config.config,
-                context_name=configuration.cluster_config.context_name,
-            )
-
-        return _get_configured_kubernetes_client_cached(cluster_config)
+        try:
+            if configuration.cluster_config:
+                client = kubernetes.config.new_client_from_config_dict(
+                    config_dict=configuration.cluster_config.config,
+                    context=configuration.cluster_config.context_name,
+                )
+            else:
+                # If no hardcoded config specified, try to load Kubernetes configuration
+                # within a cluster. If that doesn't work, try to load the configuration
+                # from the local environment, allowing any further ConfigExceptions to
+                # bubble up.
+                try:
+                    kubernetes.config.load_incluster_config()
+                    config = kubernetes.client.Configuration.get_default_copy()
+                    client = kubernetes.client.ApiClient(configuration=config)
+                except kubernetes.config.ConfigException:
+                    client = kubernetes.config.new_client_from_config()
+
+            if os.environ.get(
+                "PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE"
+            ).strip().lower() in ("true", "1"):
+                enable_socket_keep_alive(client)
+
+            yield client
+        finally:
Contributor Author:

see above comment

+            client.rest_client.pool_manager.clear()

     def _replace_api_key_with_secret(
         self, configuration: KubernetesWorkerJobConfiguration, client: "ApiClient"
tests/test_worker.py: 29 changes (1 addition, 28 deletions)
@@ -47,10 +47,7 @@

 from prefect_kubernetes import KubernetesWorker
 from prefect_kubernetes.utilities import _slugify_label_value, _slugify_name
-from prefect_kubernetes.worker import (
-    KubernetesWorkerJobConfiguration,
-    _get_configured_kubernetes_client_cached,
-)
+from prefect_kubernetes.worker import KubernetesWorkerJobConfiguration

 FAKE_CLUSTER = "fake-cluster"
 MOCK_CLUSTER_UID = "1234"
@@ -2024,7 +2021,6 @@ async def test_defaults_to_incluster_config(
         mock_cluster_config,
         mock_batch_client,
     ):
-        _get_configured_kubernetes_client_cached.cache_clear()
         mock_watch.stream = _mock_pods_stream_that_returns_running_pod

         async with KubernetesWorker(work_pool_name="test") as k8s_worker:
@@ -2042,36 +2038,13 @@ async def test_uses_cluster_config_if_not_in_cluster(
         mock_cluster_config,
         mock_batch_client,
     ):
-        _get_configured_kubernetes_client_cached.cache_clear()
         mock_watch.stream = _mock_pods_stream_that_returns_running_pod
         mock_cluster_config.load_incluster_config.side_effect = ConfigException()
         async with KubernetesWorker(work_pool_name="test") as k8s_worker:
             await k8s_worker.run(flow_run, default_configuration)

             mock_cluster_config.new_client_from_config.assert_called_once()

-    async def test_get_configured_kubernetes_client_cached(
-        self,
-        flow_run,
-        default_configuration,
-        mock_core_client,
-        mock_watch,
-        mock_cluster_config,
-        mock_batch_client,
-    ):
-        _get_configured_kubernetes_client_cached.cache_clear()
-        mock_watch.stream = _mock_pods_stream_that_returns_running_pod
-
-        assert _get_configured_kubernetes_client_cached.cache_info().hits == 0
-
-        async with KubernetesWorker(work_pool_name="test") as k8s_worker:
-            await k8s_worker.run(flow_run, default_configuration)
-            await k8s_worker.run(flow_run, default_configuration)
-            await k8s_worker.run(flow_run, default_configuration)
-
-        assert _get_configured_kubernetes_client_cached.cache_info().misses == 1
-        assert _get_configured_kubernetes_client_cached.cache_info().hits == 2
-
     @pytest.mark.parametrize("job_timeout", [24, 100])
     async def test_allows_configurable_timeouts_for_pod_and_job_watches(
         self,