From df19c1ab0707d2004c66cdbbfe011465256fb307 Mon Sep 17 00:00:00 2001 From: Uriel Mandujano Date: Tue, 30 Jan 2024 17:26:36 -0600 Subject: [PATCH] Revert "Add resilient streaming (#107)" This reverts commit f2e4d8893fd5e66d42ad8cbc1ad91f93a5c3ed8c. --- prefect_kubernetes/events.py | 9 +- prefect_kubernetes/pods.py | 12 ++- prefect_kubernetes/utilities.py | 148 +------------------------------- prefect_kubernetes/worker.py | 29 +++---- tests/conftest.py | 12 +-- tests/test_events_replicator.py | 15 ++-- tests/test_flows.py | 2 + tests/test_utilities.py | 114 ------------------------ tests/test_worker.py | 109 +++++++++++------------ 9 files changed, 87 insertions(+), 363 deletions(-) diff --git a/prefect_kubernetes/events.py b/prefect_kubernetes/events.py index 18ff4e9..3f22ed7 100644 --- a/prefect_kubernetes/events.py +++ b/prefect_kubernetes/events.py @@ -1,13 +1,10 @@ import atexit -import logging import threading from typing import TYPE_CHECKING, Dict, List, Optional from prefect.events import Event, RelatedResource, emit_event from prefect.utilities.importtools import lazy_import -from prefect_kubernetes.utilities import ResilientStreamWatcher - if TYPE_CHECKING: import kubernetes import kubernetes.client @@ -41,13 +38,11 @@ def __init__( worker_resource: Dict[str, str], related_resources: List[RelatedResource], timeout_seconds: int, - logger: Optional[logging.Logger] = None, ): self._client = client self._job_name = job_name self._namespace = namespace self._timeout_seconds = timeout_seconds - self._logger = logger # All events emitted by this replicator have the pod itself as the # resource. The `worker_resource` is what the worker uses when it's @@ -57,7 +52,7 @@ def __init__( worker_related_resource = RelatedResource(__root__=worker_resource) self._related_resources = related_resources + [worker_related_resource] - self._watch = ResilientStreamWatcher(logger=self._logger) + self._watch = kubernetes.watch.Watch() self._thread = threading.Thread(target=self._replicate_pod_events) self._state = "READY" @@ -95,7 +90,7 @@ def _replicate_pod_events(self): try: core_client = kubernetes.client.CoreV1Api(api_client=self._client) - for event in self._watch.api_object_stream( + for event in self._watch.stream( func=core_client.list_namespaced_pod, namespace=self._namespace, label_selector=f"job-name={self._job_name}", diff --git a/prefect_kubernetes/pods.py b/prefect_kubernetes/pods.py index a33a981..ab8999b 100644 --- a/prefect_kubernetes/pods.py +++ b/prefect_kubernetes/pods.py @@ -2,11 +2,11 @@ from typing import Any, Callable, Dict, Optional, Union from kubernetes.client.models import V1DeleteOptions, V1Pod, V1PodList +from kubernetes.watch import Watch from prefect import task from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_kubernetes.credentials import KubernetesCredentials -from prefect_kubernetes.utilities import ResilientStreamWatcher @task @@ -45,6 +45,7 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: + return await run_sync_in_worker_thread( core_v1_client.create_namespaced_pod, namespace=namespace, @@ -92,6 +93,7 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: + return await run_sync_in_worker_thread( core_v1_client.delete_namespaced_pod, pod_name, @@ -133,6 +135,7 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: + return await run_sync_in_worker_thread( 
core_v1_client.list_namespaced_pod, namespace=namespace, **kube_kwargs ) @@ -177,6 +180,7 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: + return await run_sync_in_worker_thread( core_v1_client.patch_namespaced_pod, name=pod_name, @@ -220,6 +224,7 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: + return await run_sync_in_worker_thread( core_v1_client.read_namespaced_pod, name=pod_name, @@ -276,11 +281,11 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: + if print_func is not None: # should no longer need to manually refresh on ApiException.status == 410 # as of https://github.com/kubernetes-client/python-base/pull/133 - watcher = ResilientStreamWatcher() - for log_line in watcher.stream( + for log_line in Watch().stream( core_v1_client.read_namespaced_pod_log, name=pod_name, namespace=namespace, @@ -336,6 +341,7 @@ def kubernetes_orchestrator(): ``` """ with kubernetes_credentials.get_client("core") as core_v1_client: + return await run_sync_in_worker_thread( core_v1_client.replace_namespaced_pod, body=new_pod, diff --git a/prefect_kubernetes/utilities.py b/prefect_kubernetes/utilities.py index 741f336..e6b5e04 100644 --- a/prefect_kubernetes/utilities.py +++ b/prefect_kubernetes/utilities.py @@ -1,13 +1,9 @@ """ Utilities for working with the Python Kubernetes API. """ -import logging import socket import sys -import time from pathlib import Path -from typing import Callable, List, Optional, Set, Type, TypeVar, Union +from typing import Optional, TypeVar, Union -import urllib3 -from kubernetes import watch from kubernetes.client import ApiClient from kubernetes.client import models as k8s_models from prefect.infrastructure.kubernetes import KubernetesJob, KubernetesManifest @@ -20,24 +16,6 @@ V1KubernetesModel = TypeVar("V1KubernetesModel") -class _CappedSet(set): - """ - A set with a bounded size. - """ - - def __init__(self, maxsize): - super().__init__() - self.maxsize = maxsize - - def add(self, value): - """ - Add to the set and maintain its max size. - """ - if len(self) >= self.maxsize: - self.pop() - super().add(value) - - def enable_socket_keep_alive(client: ApiClient) -> None: """ Setting the keep-alive flags on the kubernetes client object. @@ -199,9 +177,7 @@ def _slugify_label_key(key: str, max_length: int = 63, prefix_max_length=253) -> prefix, max_length=prefix_max_length, regex_pattern=r"[^a-zA-Z0-9-\.]+", - ).strip( - "_-." - ) # Must start or end with alphanumeric characters + ).strip("_-.") # Must start or end with alphanumeric characters or prefix ) @@ -235,123 +211,3 @@ def _slugify_label_value(value: str, max_length: int = 63) -> str: # Kubernetes to throw the validation error return slug - - -class ResilientStreamWatcher: - """ - A wrapper class around kuberenetes.watch.Watch that will reconnect on - certain exceptions. 
- """ - - DEFAULT_RECONNECT_EXCEPTIONS = (urllib3.exceptions.ProtocolError,) - - def __init__( - self, - logger: Optional[logging.Logger] = None, - max_cache_size: int = 50000, - reconnect_exceptions: Optional[List[Type[Exception]]] = None, - ) -> None: - """ - A utility class for managing streams of Kuberenetes API objects and logs - - Attributes: - logger: A logger which will be used interally to log errors - max_cache_size: The maximum number of API objects to track in an - internal cache to help deduplicate results on stream reconnects - reconnect_exceptions: A list of exceptions that will cause the stream - to reconnect. - """ - - self.max_cache_size = max_cache_size - self.logger = logger - self.watch = watch.Watch() - - reconnect_exceptions = ( - reconnect_exceptions - if reconnect_exceptions is not None - else self.DEFAULT_RECONNECT_EXCEPTIONS - ) - self.reconnect_exceptions = tuple(reconnect_exceptions) - - def stream(self, func: Callable, *args, cache: Optional[Set] = None, **kwargs): - """ - A method for streaming API objects or logs from a Kubernetes - client function. This method will reconnect the stream on certain - configurable exceptions and deduplicate results on reconnects if - streaming API objects and a cache is provided. - - Note that client functions that produce a stream will - restart a stream from the beginning of the log's history on reconnect. - If a cache is not provided, it is possible for duplicate entries to be yielded. - - Args: - func: A Kubernetes client function to call which produces a stream - of logs - *args: Positional arguments to pass to `func` - cache: A keyward argument that provides a way to deduplicate - results on reconnects and bound - **kwargs: Keyword arguments to pass to `func` - - Returns: - An iterator of log - """ - keep_streaming = True - while keep_streaming: - try: - for event in self.watch.stream(func, *args, **kwargs): - # check that we want to and can track this object - if ( - cache is not None - and isinstance(event, dict) - and "object" in event - ): - uid = event["object"].metadata.uid - if uid not in cache: - cache.add(uid) - yield event - else: - yield event - else: - # Case: we've finished iterating - keep_streaming = False - except self.reconnect_exceptions: - # Case: We've hit an exception we're willing to retry on - if self.logger: - self.logger.error("Unable to connect, retrying...", exc_info=True) - time.sleep(1) - except Exception: - # Case: We hit an exception we're unwilling to retry on - if self.logger: - self.logger.exception( - f"Unexpected error while streaming {func.__name__}" - ) - keep_streaming = False - self.stop() - raise - - self.stop() - - def api_object_stream(self, func: Callable, *args, **kwargs): - """ - Create a cache to maintain a record of API objects that have been - seen. This is useful because `stream` will reconnect a stream on - `self.reconnect_exceptions` and on reconnect it will restart streaming all - objects. This cache prevents the same object from being yielded twice. - - Args: - func: A Kubernetes client function to call which produces a stream of API o - bjects - *args: Positional arguments to pass to `func` - **kwargs: Keyword arguments to pass to `func` - - Returns: - An iterator of API objects - """ - cache = _CappedSet(self.max_cache_size) - yield from self.stream(func, *args, cache=cache, **kwargs) - - def stop(self): - """ - Shut down the internal Watch object. 
- """ - self.watch.stop() diff --git a/prefect_kubernetes/worker.py b/prefect_kubernetes/worker.py index 6cd46eb..5499c37 100644 --- a/prefect_kubernetes/worker.py +++ b/prefect_kubernetes/worker.py @@ -144,7 +144,6 @@ from prefect_kubernetes.events import KubernetesEventsReplicator from prefect_kubernetes.utilities import ( - ResilientStreamWatcher, _slugify_label_key, _slugify_label_value, _slugify_name, @@ -577,6 +576,7 @@ async def run( task_status.started(pid) # Monitor the job until completion + events_replicator = KubernetesEventsReplicator( client=client, job_name=job.metadata.name, @@ -586,7 +586,6 @@ async def run( configuration=configuration ), timeout_seconds=configuration.pod_watch_timeout_seconds, - logger=logger, ) with events_replicator: @@ -919,16 +918,15 @@ def _watch_job( if configuration.stream_output: with self._get_core_client(client) as core_client: - watch = ResilientStreamWatcher(logger=logger) + logs = core_client.read_namespaced_pod_log( + pod.metadata.name, + configuration.namespace, + follow=True, + _preload_content=False, + container="prefect-job", + ) try: - for log in watch.stream( - core_client.read_namespaced_pod_log, - pod.metadata.name, - configuration.namespace, - follow=True, - _preload_content=False, - container="prefect-job", - ): + for log in logs.stream(): print(log.decode().rstrip()) # Check if we have passed the deadline and should stop streaming @@ -938,6 +936,7 @@ def _watch_job( ) if deadline and remaining_time <= 0: break + except Exception: logger.warning( ( @@ -966,7 +965,7 @@ def _watch_job( ) return -1 - watch = ResilientStreamWatcher(logger=logger) + watch = kubernetes.watch.Watch() # The kubernetes library will disable retries if the timeout kwarg is # present regardless of the value so we do not pass it unless given # https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160 @@ -974,7 +973,7 @@ def _watch_job( {"timeout_seconds": remaining_time} if deadline else {} ) - for event in watch.api_object_stream( + for event in watch.stream( func=batch_client.list_namespaced_job, field_selector=f"metadata.name={job_name}", namespace=configuration.namespace, @@ -1074,12 +1073,12 @@ def _get_job_pod( """Get the first running pod for a job.""" from kubernetes.client.models import V1Pod - watch = ResilientStreamWatcher(logger=logger) + watch = kubernetes.watch.Watch() logger.debug(f"Job {job_name!r}: Starting watch for pod start...") last_phase = None last_pod_name: Optional[str] = None with self._get_core_client(client) as core_client: - for event in watch.api_object_stream( + for event in watch.stream( func=core_client.list_namespaced_pod, namespace=configuration.namespace, label_selector=f"job-name={job_name}", diff --git a/tests/conftest.py b/tests/conftest.py index 72bb7ca..f86a269 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,14 +4,8 @@ import pytest import yaml -from kubernetes.client import ( - ApiException, - AppsV1Api, - BatchV1Api, - CoreV1Api, - CustomObjectsApi, - models, -) +from kubernetes.client import AppsV1Api, BatchV1Api, CoreV1Api, CustomObjectsApi, models +from kubernetes.client.exceptions import ApiException from prefect.blocks.kubernetes import KubernetesClusterConfig from prefect.settings import PREFECT_LOGGING_TO_API_ENABLED, temporary_settings from prefect.testing.utilities import prefect_test_harness @@ -186,7 +180,7 @@ def mock_delete_namespaced_job(monkeypatch): @pytest.fixture def mock_stream_timeout(monkeypatch): monkeypatch.setattr( 
- "prefect_kubernetes.utilities.watch.Watch.stream", + "kubernetes.watch.Watch.stream", MagicMock(side_effect=ApiException(status=408)), ) diff --git a/tests/test_events_replicator.py b/tests/test_events_replicator.py index b34fccd..96ea3d2 100644 --- a/tests/test_events_replicator.py +++ b/tests/test_events_replicator.py @@ -9,7 +9,6 @@ from prefect.utilities.importtools import lazy_import from prefect_kubernetes.events import EVICTED_REASONS, KubernetesEventsReplicator -from prefect_kubernetes.utilities import ResilientStreamWatcher kubernetes = lazy_import("kubernetes") @@ -170,8 +169,8 @@ def test_lifecycle(replicator): def test_replicate_successful_pod_events(replicator, successful_pod_stream): - mock_watch = MagicMock(spec=ResilientStreamWatcher) - mock_watch.api_object_stream.return_value = successful_pod_stream + mock_watch = MagicMock(spec=kubernetes.watch.Watch) + mock_watch.stream.return_value = successful_pod_stream event_count = 0 @@ -258,12 +257,12 @@ def event(*args, **kwargs): ), ] ) - # mock_watch.stop.assert_called_once_with() + mock_watch.stop.assert_called_once_with() def test_replicate_failed_pod_events(replicator, failed_pod_stream): - mock_watch = MagicMock(spec=ResilientStreamWatcher) - mock_watch.api_object_stream.return_value = failed_pod_stream + mock_watch = MagicMock(spec=kubernetes.watch.Watch) + mock_watch.stream.return_value = failed_pod_stream event_count = 0 @@ -354,8 +353,8 @@ def event(*args, **kwargs): def test_replicate_evicted_pod_events(replicator, evicted_pod_stream): - mock_watch = MagicMock(spec=ResilientStreamWatcher) - mock_watch.api_object_stream.return_value = evicted_pod_stream + mock_watch = MagicMock(spec=kubernetes.watch.Watch) + mock_watch.stream.return_value = evicted_pod_stream event_count = 0 diff --git a/tests/test_flows.py b/tests/test_flows.py index e0f1d3f..0b5b84d 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -37,6 +37,7 @@ async def test_run_namespaced_job_successful( mock_list_namespaced_pod, read_pod_logs, ): + await run_namespaced_job(kubernetes_job=valid_kubernetes_job_block) assert mock_create_namespaced_job.call_count == 1 @@ -83,6 +84,7 @@ async def test_run_namespaced_job_unsuccessful( mock_list_namespaced_pod, read_pod_logs, ): + successful_job_status.status.failed = 1 successful_job_status.status.succeeded = None mock_read_namespaced_job_status.return_value = successful_job_status diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 1bdcbed..da08efd 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,18 +1,11 @@ -import logging -import uuid -from typing import Type -from unittest import mock from unittest.mock import MagicMock -import kubernetes import pytest -import urllib3 from kubernetes.client import models as k8s_models from kubernetes.config import ConfigException from prefect.infrastructure.kubernetes import KubernetesJob from prefect_kubernetes.utilities import ( - ResilientStreamWatcher, convert_manifest_to_model, enable_socket_keep_alive, ) @@ -236,113 +229,6 @@ def test_bad_model_type_raises(v1_model_name): convert_manifest_to_model(sample_deployment_manifest, v1_model_name) -def test_resilient_streaming_retries_on_configured_errors(caplog): - watcher = ResilientStreamWatcher(logger=logging.getLogger("test")) - - with mock.patch.object( - watcher.watch, - "stream", - side_effect=[ - watcher.reconnect_exceptions[0], - watcher.reconnect_exceptions[0], - ["random_success"], - ], - ) as mocked_stream: - for log in watcher.api_object_stream(str): - assert log == 
"random_success" - - assert mocked_stream.call_count == 3 - assert "Unable to connect, retrying..." in caplog.text - - -@pytest.mark.parametrize( - "exc", [Exception, TypeError, ValueError, urllib3.exceptions.ProtocolError] -) -def test_resilient_streaming_raises_on_unconfigured_errors( - exc: Type[Exception], caplog -): - watcher = ResilientStreamWatcher( - logger=logging.getLogger("test"), reconnect_exceptions=[] - ) - - with mock.patch.object(watcher.watch, "stream", side_effect=[exc]) as mocked_stream: - with pytest.raises(exc): - for _ in watcher.api_object_stream(str): - pass - - assert mocked_stream.call_count == 1 - assert "Unexpected error" in caplog.text - assert exc.__name__ in caplog.text - - -def _create_api_objects_mocks(n: int = 3): - objects = [] - for _ in range(n): - o = mock.MagicMock(spec=kubernetes.client.V1Pod) - o.metadata = mock.PropertyMock() - o.metadata.uid = uuid.uuid4() - objects.append(o) - return objects - - -def test_resilient_streaming_deduplicates_api_objects_on_reconnects(): - watcher = ResilientStreamWatcher(logger=logging.getLogger("test")) - - object_pool = _create_api_objects_mocks() - thrown_exceptions = 0 - - def my_stream(*args, **kwargs): - """ - Simulate a stream that throws exceptions after yielding the first - object before yielding the rest of the objects. - """ - for o in object_pool: - yield {"object": o} - - nonlocal thrown_exceptions - if thrown_exceptions < 3: - thrown_exceptions += 1 - raise watcher.reconnect_exceptions[0] - - watcher.watch.stream = my_stream - results = [obj for obj in watcher.api_object_stream(str)] - - assert len(object_pool) == len(results) - - -def test_resilient_streaming_pulls_all_logs_on_reconnects(): - watcher = ResilientStreamWatcher(logger=logging.getLogger("test")) - - logs = ["log1", "log2", "log3", "log4"] - thrown_exceptions = 0 - - def my_stream(*args, **kwargs): - """ - Simulate a stream that throws exceptions after yielding the first - object before yielding the rest of the objects. 
- """ - for log in logs: - yield log - - nonlocal thrown_exceptions - if thrown_exceptions < 3: - thrown_exceptions += 1 - raise watcher.reconnect_exceptions[0] - - watcher.watch.stream = my_stream - results = [obj for obj in watcher.stream(str)] - - assert results == [ - "log1", # Before first exception - "log1", # Before second exception - "log1", # Before third exception - "log1", # No more exceptions from here onward - "log2", - "log3", - "log4", - ] - - def test_keep_alive_updates_socket_options(mock_api_client): enable_socket_keep_alive(mock_api_client) diff --git a/tests/test_worker.py b/tests/test_worker.py index dbf5dd6..708e6a6 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -67,11 +67,7 @@ def mock_watch(monkeypatch): mock = MagicMock() - monkeypatch.setattr( - "prefect_kubernetes.worker.ResilientStreamWatcher", - MagicMock(return_value=mock), - raising=True, - ) + monkeypatch.setattr("kubernetes.watch.Watch", MagicMock(return_value=mock)) return mock @@ -205,7 +201,7 @@ def enable_store_api_key_in_secret(monkeypatch): stream_output=True, ), lambda flow_run, deployment, flow: KubernetesWorkerJobConfiguration( - command="prefect flow-run execute", + command="python -m prefect.engine", env={ **get_current_settings().to_environment_variables(exclude_unset=True), "PREFECT__FLOW_RUN_ID": str(flow_run.id), @@ -263,11 +259,7 @@ def enable_store_api_key_in_secret(monkeypatch): }, ], "image": get_prefect_image_name(), - "args": [ - "prefect", - "flow-run", - "execute", - ], + "args": ["python", "-m", "prefect.engine"], } ], } @@ -486,7 +478,7 @@ def enable_store_api_key_in_secret(monkeypatch): stream_output=True, ), lambda flow_run, deployment, flow: KubernetesWorkerJobConfiguration( - command="prefect flow-run execute", + command="python -m prefect.engine", env={ **get_current_settings().to_environment_variables(exclude_unset=True), "PREFECT__FLOW_RUN_ID": str(flow_run.id), @@ -553,11 +545,7 @@ def enable_store_api_key_in_secret(monkeypatch): }, ], "image": get_prefect_image_name(), - "args": [ - "prefect", - "flow-run", - "execute", - ], + "args": ["python", "-m", "prefect.engine"], } ], } @@ -1223,7 +1211,7 @@ async def test_user_can_supply_a_sidecar_container_and_volume(self, flow_run): # the prefect-job container is still populated assert pod["containers"][0]["name"] == "prefect-job" - assert pod["containers"][0]["args"] == ["prefect", "flow-run", "execute"] + assert pod["containers"][0]["args"] == ["python", "-m", "prefect.engine"] assert pod["containers"][1] == { "name": "my-sidecar", @@ -1251,9 +1239,7 @@ async def test_creates_job_by_building_a_manifest( mock_core_client, mock_watch, ): - mock_watch.api_object_stream.return_value = ( - _mock_pods_stream_that_returns_running_pod() - ) + mock_watch.stream = _mock_pods_stream_that_returns_running_pod default_configuration.prepare_for_flow_run(flow_run) expected_manifest = default_configuration.job_manifest @@ -1373,7 +1359,7 @@ async def test_job_name_creates_valid_name( job_name, clean_name, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod default_configuration.name = job_name default_configuration.prepare_for_flow_run(flow_run) async with KubernetesWorker(work_pool_name="test") as k8s_worker: @@ -1391,7 +1377,7 @@ async def test_uses_image_variable( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod 
configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image": "foo"} @@ -1412,7 +1398,7 @@ async def test_can_store_api_key_in_secret( mock_batch_client, enable_store_api_key_in_secret, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod mock_core_client.read_namespaced_secret.side_effect = ApiException(status=404) configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( @@ -1466,7 +1452,7 @@ async def test_store_api_key_in_existing_secret( mock_batch_client, enable_store_api_key_in_secret, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image": "foo"} @@ -1668,7 +1654,7 @@ async def test_allows_image_setting_from_manifest( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod default_configuration.job_manifest["spec"]["template"]["spec"]["containers"][0][ "image" @@ -1690,7 +1676,7 @@ async def test_uses_labels_setting( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), @@ -1807,7 +1793,7 @@ async def test_sanitizes_user_label_keys( given, expected, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"labels": {given: "foo"}}, @@ -1855,7 +1841,7 @@ async def test_sanitizes_user_label_values( given, expected, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), @@ -1878,7 +1864,7 @@ async def test_uses_namespace_setting( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"namespace": "foo"}, @@ -1900,7 +1886,7 @@ async def test_allows_namespace_setting_from_manifest( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod default_configuration.job_manifest["metadata"]["namespace"] = "test" default_configuration.prepare_for_flow_run(flow_run) @@ -1920,7 +1906,7 @@ async def test_uses_service_account_name_setting( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( 
KubernetesWorker.get_default_base_job_template(), {"service_account_name": "foo"}, @@ -1941,7 +1927,7 @@ async def test_uses_finished_job_ttl_setting( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"finished_job_ttl": 123}, @@ -1962,7 +1948,7 @@ async def test_uses_specified_image_pull_policy( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image_pull_policy": "IfNotPresent"}, @@ -1984,7 +1970,7 @@ async def test_defaults_to_incluster_config( mock_cluster_config, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) @@ -2001,7 +1987,7 @@ async def test_uses_cluster_config_if_not_in_cluster( mock_cluster_config, mock_batch_client, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod mock_cluster_config.load_incluster_config.side_effect = ConfigException() async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) @@ -2018,9 +2004,7 @@ async def test_allows_configurable_timeouts_for_pod_and_job_watches( default_configuration: KubernetesWorkerJobConfiguration, flow_run, ): - mock_watch.api_object_stream = Mock( - side_effect=_mock_pods_stream_that_returns_running_pod - ) + mock_watch.stream = Mock(side_effect=_mock_pods_stream_that_returns_running_pod) # The job should not be completed to start mock_batch_client.read_namespaced_job.return_value.status.completion_time = None @@ -2047,7 +2031,7 @@ async def test_allows_configurable_timeouts_for_pod_and_job_watches( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_watch.api_object_stream.assert_has_calls( + mock_watch.stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2069,7 +2053,7 @@ async def test_excludes_timeout_from_job_watches_when_null( mock_batch_client, job_timeout, ): - mock_watch.api_object_stream = mock.Mock( + mock_watch.stream = mock.Mock( side_effect=_mock_pods_stream_that_returns_running_pod ) # The job should not be completed to start @@ -2080,7 +2064,7 @@ async def test_excludes_timeout_from_job_watches_when_null( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_watch.api_object_stream.assert_has_calls( + mock_watch.stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2105,7 +2089,7 @@ async def test_watches_the_right_namespace( mock_watch, mock_batch_client, ): - mock_watch.api_object_stream = mock.Mock( + mock_watch.stream = mock.Mock( side_effect=_mock_pods_stream_that_returns_running_pod ) # The job should not be completed to start @@ -2116,7 +2100,7 @@ async def test_watches_the_right_namespace( async with 
KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_watch.api_object_stream.assert_has_calls( + mock_watch.stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2141,11 +2125,14 @@ async def test_streaming_pod_logs_timeout_warns( mock_batch_client, caplog, ): - mock_watch.api_object_stream = _mock_pods_stream_that_returns_running_pod + mock_watch.stream = _mock_pods_stream_that_returns_running_pod # The job should not be completed to start mock_batch_client.read_namespaced_job.return_value.status.completion_time = None - mock_watch.stream = MagicMock(side_effect=RuntimeError("something went wrong")) + mock_logs = MagicMock() + mock_logs.stream = MagicMock(side_effect=RuntimeError("something went wrong")) + + mock_core_client.read_namespaced_pod_log = MagicMock(return_value=mock_logs) async with KubernetesWorker(work_pool_name="test") as k8s_worker: with caplog.at_level("WARNING"): @@ -2178,7 +2165,7 @@ def mock_stream(*args, **kwargs): sleep(0.5) yield {"object": job} - mock_watch.api_object_stream.side_effect = mock_stream + mock_watch.stream.side_effect = mock_stream default_configuration.pod_watch_timeout_seconds = 42 default_configuration.job_watch_timeout_seconds = 0 @@ -2219,7 +2206,7 @@ def mock_log_stream(*args, **kwargs): return MagicMock() mock_core_client.read_namespaced_pod_log.side_effect = mock_log_stream - mock_watch.api_object_stream.side_effect = mock_stream + mock_watch.stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 1000 async with KubernetesWorker(work_pool_name="test") as k8s_worker: @@ -2227,7 +2214,7 @@ def mock_log_stream(*args, **kwargs): assert result.status_code == 1 - mock_watch.api_object_stream.assert_has_calls( + mock_watch.stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2269,7 +2256,7 @@ def mock_stream(*args, **kwargs): # Yield the job then return exiting the stream # After restarting the watch a few times, we'll report completion job.status.completion_time = ( - None if mock_watch.api_object_stream.call_count < 3 else True + None if mock_watch.stream.call_count < 3 else True ) yield {"object": job} @@ -2278,8 +2265,8 @@ def mock_log_stream(*args, **kwargs): sleep(0.25) yield f"test {i}".encode() - mock_watch.stream = mock_log_stream - mock_watch.api_object_stream.side_effect = mock_stream + mock_core_client.read_namespaced_pod_log.return_value.stream = mock_log_stream + mock_watch.stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 1 @@ -2289,7 +2276,7 @@ def mock_log_stream(*args, **kwargs): # The job should timeout assert result.status_code == -1 - mock_watch.api_object_stream.assert_has_calls( + mock_watch.stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2333,8 +2320,8 @@ def mock_log_stream(*args, **kwargs): sleep(0.25) yield f"test {i}".encode() - mock_watch.stream = mock_log_stream - mock_watch.api_object_stream.side_effect = mock_stream + mock_core_client.read_namespaced_pod_log.return_value.stream = mock_log_stream + mock_watch.stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 1 async with KubernetesWorker(work_pool_name="test") as k8s_worker: @@ -2343,7 +2330,7 @@ def mock_log_stream(*args, **kwargs): # The job should not timeout assert result.status_code == 1 - mock_watch.api_object_stream.assert_has_calls( + mock_watch.stream.assert_has_calls( [ mock.call( 
func=mock_core_client.list_namespaced_pod, @@ -2395,14 +2382,14 @@ def mock_stream(*args, **kwargs): job.spec.backoff_limit = 6 yield {"object": job, "type": "ADDED"} - mock_watch.api_object_stream.side_effect = mock_stream + mock_watch.stream.side_effect = mock_stream default_configuration.job_watch_timeout_seconds = 40 async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration) assert result.status_code == -1 - mock_watch.api_object_stream.assert_has_calls( + mock_watch.stream.assert_has_calls( [ mock.call( func=mock_core_client.list_namespaced_pod, @@ -2470,7 +2457,7 @@ def mock_stream(*args, **kwargs): job.status.failed = i yield {"object": job, "type": "ADDED"} - mock_watch.api_object_stream.side_effect = mock_stream + mock_watch.stream.side_effect = mock_stream async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration) @@ -2505,7 +2492,7 @@ def mock_stream(*args, **kwargs): job.status.failed = i yield {"object": job, "type": "ADDED"} - mock_watch.api_object_stream.side_effect = mock_stream + mock_watch.stream.side_effect = mock_stream async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration) @@ -2550,7 +2537,7 @@ def mock_stream(*args, **kwargs): job.status.failed = i yield {"object": job, "type": "ADDED"} - mock_watch.api_object_stream.side_effect = mock_stream + mock_watch.stream.side_effect = mock_stream async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration)
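
For reviewers' reference, a minimal standalone sketch of the plain `kubernetes.watch.Watch` pattern this revert restores in place of `ResilientStreamWatcher`. This is an illustration outside the patch, not part of it: the namespace, job, and pod names are hypothetical placeholders, and reconnection/deduplication logic is intentionally absent because the revert removes it.

    import kubernetes.client
    import kubernetes.config
    import kubernetes.watch

    kubernetes.config.load_kube_config()  # or load_incluster_config() when in-cluster
    core_client = kubernetes.client.CoreV1Api()

    namespace = "default"         # hypothetical
    job_name = "example-job"      # hypothetical
    pod_name = "example-job-abc"  # hypothetical

    # Pod/job watching as in worker.py after the revert: a bare Watch.stream()
    # call. Unlike the removed ResilientStreamWatcher, this neither reconnects
    # on urllib3 ProtocolError nor deduplicates events across stream restarts.
    watch = kubernetes.watch.Watch()
    for event in watch.stream(
        func=core_client.list_namespaced_pod,
        namespace=namespace,
        label_selector=f"job-name={job_name}",
        timeout_seconds=60,
    ):
        phase = event["object"].status.phase
        print(f"{event['type']}: pod phase={phase}")
        if phase != "Pending":
            watch.stop()

    # Log streaming as restored in _watch_job: read_namespaced_pod_log with
    # _preload_content=False returns a urllib3 response whose .stream()
    # yields raw byte chunks.
    logs = core_client.read_namespaced_pod_log(
        pod_name,
        namespace,
        follow=True,
        _preload_content=False,
        container="prefect-job",
    )
    for line in logs.stream():
        print(line.decode().rstrip())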