Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix/revert resilient streaming (#109)
Browse files Browse the repository at this point in the history
* Revert "Add resilient streaming (#107)"

This reverts commit f2e4d88.

* ft: update changelog

* ft: unreverts some testing fixes

* fix: pre-commit formatting
  • Loading branch information
urimandujano authored Feb 1, 2024
1 parent 675f6d7 commit 4c176fc
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 347 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Reverting [#107](https://github.com/PrefectHQ/prefect-kubernetes/pull/107) to address
deadlocking issue.

### Security

## 0.3.3
Expand Down
9 changes: 2 additions & 7 deletions prefect_kubernetes/events.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import atexit
import logging
import threading
from typing import TYPE_CHECKING, Dict, List, Optional

from prefect.events import Event, RelatedResource, emit_event
from prefect.utilities.importtools import lazy_import

from prefect_kubernetes.utilities import ResilientStreamWatcher

if TYPE_CHECKING:
import kubernetes
import kubernetes.client
Expand Down Expand Up @@ -41,13 +38,11 @@ def __init__(
worker_resource: Dict[str, str],
related_resources: List[RelatedResource],
timeout_seconds: int,
logger: Optional[logging.Logger] = None,
):
self._client = client
self._job_name = job_name
self._namespace = namespace
self._timeout_seconds = timeout_seconds
self._logger = logger

# All events emitted by this replicator have the pod itself as the
# resource. The `worker_resource` is what the worker uses when it's
Expand All @@ -57,7 +52,7 @@ def __init__(
worker_related_resource = RelatedResource(__root__=worker_resource)
self._related_resources = related_resources + [worker_related_resource]

self._watch = ResilientStreamWatcher(logger=self._logger)
self._watch = kubernetes.watch.Watch()
self._thread = threading.Thread(target=self._replicate_pod_events)

self._state = "READY"
Expand Down Expand Up @@ -95,7 +90,7 @@ def _replicate_pod_events(self):

try:
core_client = kubernetes.client.CoreV1Api(api_client=self._client)
for event in self._watch.api_object_stream(
for event in self._watch.stream(
func=core_client.list_namespaced_pod,
namespace=self._namespace,
label_selector=f"job-name={self._job_name}",
Expand Down
12 changes: 9 additions & 3 deletions prefect_kubernetes/pods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, Callable, Dict, Optional, Union

from kubernetes.client.models import V1DeleteOptions, V1Pod, V1PodList
from kubernetes.watch import Watch
from prefect import task
from prefect.utilities.asyncutils import run_sync_in_worker_thread

from prefect_kubernetes.credentials import KubernetesCredentials
from prefect_kubernetes.utilities import ResilientStreamWatcher


@task
Expand Down Expand Up @@ -45,6 +45,7 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("core") as core_v1_client:

return await run_sync_in_worker_thread(
core_v1_client.create_namespaced_pod,
namespace=namespace,
Expand Down Expand Up @@ -92,6 +93,7 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("core") as core_v1_client:

return await run_sync_in_worker_thread(
core_v1_client.delete_namespaced_pod,
pod_name,
Expand Down Expand Up @@ -133,6 +135,7 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("core") as core_v1_client:

return await run_sync_in_worker_thread(
core_v1_client.list_namespaced_pod, namespace=namespace, **kube_kwargs
)
Expand Down Expand Up @@ -177,6 +180,7 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("core") as core_v1_client:

return await run_sync_in_worker_thread(
core_v1_client.patch_namespaced_pod,
name=pod_name,
Expand Down Expand Up @@ -220,6 +224,7 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("core") as core_v1_client:

return await run_sync_in_worker_thread(
core_v1_client.read_namespaced_pod,
name=pod_name,
Expand Down Expand Up @@ -276,11 +281,11 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("core") as core_v1_client:

if print_func is not None:
# should no longer need to manually refresh on ApiException.status == 410
# as of https://github.com/kubernetes-client/python-base/pull/133
watcher = ResilientStreamWatcher()
for log_line in watcher.stream(
for log_line in Watch().stream(
core_v1_client.read_namespaced_pod_log,
name=pod_name,
namespace=namespace,
Expand Down Expand Up @@ -336,6 +341,7 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("core") as core_v1_client:

return await run_sync_in_worker_thread(
core_v1_client.replace_namespaced_pod,
body=new_pod,
Expand Down
144 changes: 1 addition & 143 deletions prefect_kubernetes/utilities.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
""" Utilities for working with the Python Kubernetes API. """
import logging
import socket
import sys
import time
from pathlib import Path
from typing import Callable, List, Optional, Set, Type, TypeVar, Union
from typing import Optional, TypeVar, Union

import urllib3
from kubernetes import watch
from kubernetes.client import ApiClient
from kubernetes.client import models as k8s_models
from prefect.infrastructure.kubernetes import KubernetesJob, KubernetesManifest
Expand All @@ -20,24 +16,6 @@
V1KubernetesModel = TypeVar("V1KubernetesModel")


class _CappedSet(set):
"""
A set with a bounded size.
"""

def __init__(self, maxsize):
super().__init__()
self.maxsize = maxsize

def add(self, value):
"""
Add to the set and maintain its max size.
"""
if len(self) >= self.maxsize:
self.pop()
super().add(value)


def enable_socket_keep_alive(client: ApiClient) -> None:
"""
Setting the keep-alive flags on the kubernetes client object.
Expand Down Expand Up @@ -235,123 +213,3 @@ def _slugify_label_value(value: str, max_length: int = 63) -> str:
# Kubernetes to throw the validation error

return slug


class ResilientStreamWatcher:
"""
A wrapper class around kuberenetes.watch.Watch that will reconnect on
certain exceptions.
"""

DEFAULT_RECONNECT_EXCEPTIONS = (urllib3.exceptions.ProtocolError,)

def __init__(
self,
logger: Optional[logging.Logger] = None,
max_cache_size: int = 50000,
reconnect_exceptions: Optional[List[Type[Exception]]] = None,
) -> None:
"""
A utility class for managing streams of Kuberenetes API objects and logs
Attributes:
logger: A logger which will be used interally to log errors
max_cache_size: The maximum number of API objects to track in an
internal cache to help deduplicate results on stream reconnects
reconnect_exceptions: A list of exceptions that will cause the stream
to reconnect.
"""

self.max_cache_size = max_cache_size
self.logger = logger
self.watch = watch.Watch()

reconnect_exceptions = (
reconnect_exceptions
if reconnect_exceptions is not None
else self.DEFAULT_RECONNECT_EXCEPTIONS
)
self.reconnect_exceptions = tuple(reconnect_exceptions)

def stream(self, func: Callable, *args, cache: Optional[Set] = None, **kwargs):
"""
A method for streaming API objects or logs from a Kubernetes
client function. This method will reconnect the stream on certain
configurable exceptions and deduplicate results on reconnects if
streaming API objects and a cache is provided.
Note that client functions that produce a stream will
restart a stream from the beginning of the log's history on reconnect.
If a cache is not provided, it is possible for duplicate entries to be yielded.
Args:
func: A Kubernetes client function to call which produces a stream
of logs
*args: Positional arguments to pass to `func`
cache: A keyward argument that provides a way to deduplicate
results on reconnects and bound
**kwargs: Keyword arguments to pass to `func`
Returns:
An iterator of log
"""
keep_streaming = True
while keep_streaming:
try:
for event in self.watch.stream(func, *args, **kwargs):
# check that we want to and can track this object
if (
cache is not None
and isinstance(event, dict)
and "object" in event
):
uid = event["object"].metadata.uid
if uid not in cache:
cache.add(uid)
yield event
else:
yield event
else:
# Case: we've finished iterating
keep_streaming = False
except self.reconnect_exceptions:
# Case: We've hit an exception we're willing to retry on
if self.logger:
self.logger.error("Unable to connect, retrying...", exc_info=True)
time.sleep(1)
except Exception:
# Case: We hit an exception we're unwilling to retry on
if self.logger:
self.logger.exception(
f"Unexpected error while streaming {func.__name__}"
)
keep_streaming = False
self.stop()
raise

self.stop()

def api_object_stream(self, func: Callable, *args, **kwargs):
"""
Create a cache to maintain a record of API objects that have been
seen. This is useful because `stream` will reconnect a stream on
`self.reconnect_exceptions` and on reconnect it will restart streaming all
objects. This cache prevents the same object from being yielded twice.
Args:
func: A Kubernetes client function to call which produces a stream of API o
bjects
*args: Positional arguments to pass to `func`
**kwargs: Keyword arguments to pass to `func`
Returns:
An iterator of API objects
"""
cache = _CappedSet(self.max_cache_size)
yield from self.stream(func, *args, cache=cache, **kwargs)

def stop(self):
"""
Shut down the internal Watch object.
"""
self.watch.stop()
29 changes: 14 additions & 15 deletions prefect_kubernetes/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@

from prefect_kubernetes.events import KubernetesEventsReplicator
from prefect_kubernetes.utilities import (
ResilientStreamWatcher,
_slugify_label_key,
_slugify_label_value,
_slugify_name,
Expand Down Expand Up @@ -577,6 +576,7 @@ async def run(
task_status.started(pid)

# Monitor the job until completion

events_replicator = KubernetesEventsReplicator(
client=client,
job_name=job.metadata.name,
Expand All @@ -586,7 +586,6 @@ async def run(
configuration=configuration
),
timeout_seconds=configuration.pod_watch_timeout_seconds,
logger=logger,
)

with events_replicator:
Expand Down Expand Up @@ -919,16 +918,15 @@ def _watch_job(

if configuration.stream_output:
with self._get_core_client(client) as core_client:
watch = ResilientStreamWatcher(logger=logger)
logs = core_client.read_namespaced_pod_log(
pod.metadata.name,
configuration.namespace,
follow=True,
_preload_content=False,
container="prefect-job",
)
try:
for log in watch.stream(
core_client.read_namespaced_pod_log,
pod.metadata.name,
configuration.namespace,
follow=True,
_preload_content=False,
container="prefect-job",
):
for log in logs.stream():
print(log.decode().rstrip())

# Check if we have passed the deadline and should stop streaming
Expand All @@ -938,6 +936,7 @@ def _watch_job(
)
if deadline and remaining_time <= 0:
break

except Exception:
logger.warning(
(
Expand Down Expand Up @@ -966,15 +965,15 @@ def _watch_job(
)
return -1

watch = ResilientStreamWatcher(logger=logger)
watch = kubernetes.watch.Watch()
# The kubernetes library will disable retries if the timeout kwarg is
# present regardless of the value so we do not pass it unless given
# https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160
timeout_seconds = (
{"timeout_seconds": remaining_time} if deadline else {}
)

for event in watch.api_object_stream(
for event in watch.stream(
func=batch_client.list_namespaced_job,
field_selector=f"metadata.name={job_name}",
namespace=configuration.namespace,
Expand Down Expand Up @@ -1074,12 +1073,12 @@ def _get_job_pod(
"""Get the first running pod for a job."""
from kubernetes.client.models import V1Pod

watch = ResilientStreamWatcher(logger=logger)
watch = kubernetes.watch.Watch()
logger.debug(f"Job {job_name!r}: Starting watch for pod start...")
last_phase = None
last_pod_name: Optional[str] = None
with self._get_core_client(client) as core_client:
for event in watch.api_object_stream(
for event in watch.stream(
func=core_client.list_namespaced_pod,
namespace=configuration.namespace,
label_selector=f"job-name={job_name}",
Expand Down
Loading

0 comments on commit 4c176fc

Please sign in to comment.