Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Handle 410s during Job watch (#129)
Browse files Browse the repository at this point in the history
* handle 410s

* remove unused mock

* restore `watch.stop()`

* watch in loop

* add typehint

* union

* Update prefect_kubernetes/worker.py

Co-authored-by: nate nowack <[email protected]>

* improve test, fix list call

* remove debug prints

* test cleanup

---------

Co-authored-by: nate nowack <[email protected]>
  • Loading branch information
kevingrismore and zzstoatzz authored Mar 21, 2024
1 parent d1a5643 commit cb530ad
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 9 deletions.
52 changes: 43 additions & 9 deletions prefect_kubernetes/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
from datetime import datetime
from functools import lru_cache
from threading import Lock
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union

import anyio.abc
from kubernetes.client.exceptions import ApiException
Expand Down Expand Up @@ -950,6 +950,40 @@ def _get_cluster_uid(self, client: "ApiClient") -> str:

return cluster_uid

def _job_events(
    self,
    watch: kubernetes.watch.Watch,
    batch_client: kubernetes.client.BatchV1Api,
    job_name: str,
    namespace: str,
    watch_kwargs: dict,
) -> Generator[Union[Any, dict, str], Any, None]:
    """
    Stream job events.

    Pick up from the current resource version returned by the API
    in the case of a 410.

    See https://kubernetes.io/docs/reference/using-api/api-concepts/#efficient-detection-of-changes  # noqa

    Args:
        watch: The watch used to stream events.
        batch_client: Client used both as the watch's list function and to
            re-list the job after a 410.
        job_name: Name of the job; used to build the `metadata.name` field
            selector.
        namespace: Namespace the job is listed in.
        watch_kwargs: Extra keyword arguments forwarded to `watch.stream`.
            This mapping is copied and is not modified.

    Yields:
        Each event produced by `watch.stream`, unchanged.

    Raises:
        ApiException: Any API error whose status is not 410.
    """
    field_selector = f"metadata.name={job_name}"

    # Work on a copy so that adding `resource_version` after a 410 does not
    # leak into the caller's dictionary.
    stream_kwargs = dict(watch_kwargs)

    while True:
        try:
            # The stream is consumed lazily, so API errors surface while
            # iterating rather than when the stream is created. Delegating
            # with `yield from` keeps the iteration inside this `try` so a
            # 410 raised mid-stream is actually caught here.
            yield from watch.stream(
                func=batch_client.list_namespaced_job,
                namespace=namespace,
                field_selector=field_selector,
                **stream_kwargs,
            )
            # The stream ended without error; there is nothing to resume.
            return
        except ApiException as e:
            if e.status != 410:
                raise

            # 410: re-list the job to obtain the current resource version,
            # then loop around and restart the watch from there.
            job_list = batch_client.list_namespaced_job(
                namespace=namespace, field_selector=field_selector
            )
            stream_kwargs["resource_version"] = job_list.metadata.resource_version

def _watch_job(
self,
logger: logging.Logger,
Expand Down Expand Up @@ -1029,18 +1063,18 @@ def _watch_job(
return -1

watch = kubernetes.watch.Watch()

# The kubernetes library will disable retries if the timeout kwarg is
# present regardless of the value so we do not pass it unless given
# https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160
timeout_seconds = (
{"timeout_seconds": remaining_time} if deadline else {}
)
watch_kwargs = {"timeout_seconds": remaining_time} if deadline else {}

for event in watch.stream(
func=batch_client.list_namespaced_job,
field_selector=f"metadata.name={job_name}",
namespace=configuration.namespace,
**timeout_seconds,
for event in self._job_events(
watch,
batch_client,
job_name,
configuration.namespace,
watch_kwargs,
):
if event["type"] == "DELETED":
logger.error(f"Job {job_name!r}: Job has been deleted.")
Expand Down
42 changes: 42 additions & 0 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2622,6 +2622,48 @@ def mock_stream(*args, **kwargs):

assert result.status_code == -1

async def test_watch_handles_410(
    self,
    default_configuration: KubernetesWorkerJobConfiguration,
    flow_run,
    mock_batch_client,
    mock_core_client,
    mock_watch,
):
    """
    A 410 raised by the watch stream should cause the job to be re-listed
    and the watch to be restarted with the returned resource version.
    """
    # The job should not be completed to start
    mock_batch_client.read_namespaced_job.return_value.status.completion_time = None

    # Listing the job after the 410 reports resource version "1".
    relisted_jobs = MagicMock(spec=kubernetes.client.V1JobList)
    relisted_jobs.metadata.resource_version = "1"
    mock_batch_client.list_namespaced_job.side_effect = [relisted_jobs]

    # Successive `stream` calls: two working streams, then a 410, then a
    # working stream again once the watch is restarted.
    mock_watch.stream.side_effect = [
        _mock_pods_stream_that_returns_running_pod(),
        _mock_pods_stream_that_returns_running_pod(),
        ApiException(status=410),
        _mock_pods_stream_that_returns_running_pod(),
    ]

    async with KubernetesWorker(work_pool_name="test") as k8s_worker:
        await k8s_worker.run(flow_run=flow_run, configuration=default_configuration)

    # Keyword arguments shared by the initial and the resumed job watch.
    job_watch_kwargs = {
        "func": mock_batch_client.list_namespaced_job,
        "namespace": mock.ANY,
        "field_selector": "metadata.name=mock-job",
    }
    initial_watch = mock.call(**job_watch_kwargs)
    resumed_watch = mock.call(**job_watch_kwargs, resource_version="1")

    mock_watch.stream.assert_has_calls([initial_watch, resumed_watch])

class TestKillInfrastructure:
async def test_kill_infrastructure_calls_delete_namespaced_job(
self,
Expand Down

0 comments on commit cb530ad

Please sign in to comment.