A quick attempt at making taskproc handle migrations from one k8s clu… #213

Merged: 6 commits, Jul 11, 2024
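
This PR adds an `old_kubeconfig_paths` argument to `KubernetesPodExecutor` so that, while tasks migrate from one Kubernetes cluster to another, the executor can keep watching, reconciling, and killing pods that still live on the previous cluster(s). A minimal usage sketch, assuming only the arguments visible in this diff (both kubeconfig paths below are made up for illustration):

```python
from task_processing.plugins.kubernetes.kubernetes_pod_executor import (
    KubernetesPodExecutor,
)

# Sketch only: the kubeconfig paths are hypothetical examples.
executor = KubernetesPodExecutor(
    namespace="task_processing_tests",
    kubeconfig_path="/etc/kubernetes/new-cluster.conf",
    # Clusters we are migrating away from; the executor will also watch,
    # reconcile, and terminate pods there.
    old_kubeconfig_paths=["/etc/kubernetes/old-cluster.conf"],
)
```

As far as this diff shows, new pods are still launched only through the primary `kube_client`; the old clients are consulted for event watching, `reconcile()` lookups, and `kill()`.
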
79 changes: 53 additions & 26 deletions task_processing/plugins/kubernetes/kubernetes_pod_executor.py
@@ -6,7 +6,7 @@
from typing import Collection
from typing import Optional

from kubernetes import watch
from kubernetes import watch as kube_watch
from kubernetes.client import V1Affinity
from kubernetes.client import V1Container
from kubernetes.client import V1ContainerPort
@@ -72,13 +72,20 @@ def __init__(
kubeconfig_path: Optional[str] = None,
task_configs: Optional[Collection[KubernetesTaskConfig]] = [],
emit_events_without_state_transitions: bool = False,
old_kubeconfig_paths: Collection[str] = (),
) -> None:
if not version:
version = "unknown_task_processing"
user_agent = f"{namespace}/v{version}"
self.kube_client = KubeClient(
kubeconfig_path=kubeconfig_path, user_agent=user_agent
)

self.old_kube_clients = [
KubeClient(kubeconfig_path=old_kubeconfig_path, user_agent=user_agent)
for old_kubeconfig_path in old_kubeconfig_paths
]

self.namespace = namespace

# Pod modified events that did not result in a pod state transition are usually not
@@ -106,17 +113,23 @@ def __init__(

# TODO(TASKPROC-243): keep track of resourceVersion so that we can continue event processing
# from where we left off on restarts
self.watch = watch.Watch()
self.pod_event_watch_thread = threading.Thread(
target=self._pod_event_watch_loop,
# ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
# if it should stop after receiving an event - and it's possible that we
# have periods with no events so instead we'll attempt to stop the watch
# and then join() with a small timeout to make sure that, if we shutdown
# with the thread alive, we did not drop any events
daemon=True,
)
self.pod_event_watch_thread.start()
self.pod_event_watch_threads = []
self.watches = []
for kube_client in [self.kube_client] + self.old_kube_clients:
watch = kube_watch.Watch()
pod_event_watch_thread = threading.Thread(
target=self._pod_event_watch_loop,
args=(kube_client, watch),
# ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
# if it should stop after receiving an event - and it's possible that we
# have periods with no events so instead we'll attempt to stop the watch
# and then join() with a small timeout to make sure that, if we shutdown
# with the thread alive, we did not drop any events
daemon=True,
)
pod_event_watch_thread.start()
self.pod_event_watch_threads.append(pod_event_watch_thread)
self.watches.append(watch)

self.pending_event_processing_thread = threading.Thread(
target=self._pending_event_processing_loop,
@@ -143,7 +156,9 @@ def _initialize_existing_task(self, task_config: KubernetesTaskConfig) -> None:
),
)

def _pod_event_watch_loop(self) -> None:
def _pod_event_watch_loop(
self, kube_client: KubeClient, watch: kube_watch.Watch
) -> None:
logger.debug(f"Starting watching Pod events for namespace={self.namespace}.")
# TODO(TASKPROC-243): we'll need to correctly handle resourceVersion expiration for the case
# where the gap between task_proc shutting down and coming back up is long enough for data
@@ -155,8 +170,8 @@ def _pod_event_watch_loop(self) -> None:
# see: https://github.com/kubernetes/kubernetes/issues/74022
while not self.stopping:
try:
for pod_event in self.watch.stream(
self.kube_client.core.list_namespaced_pod, self.namespace
for pod_event in watch.stream(
kube_client.core.list_namespaced_pod, self.namespace
):
# it's possible that we've received an event after we've already set the stop
# flag since Watch streams block forever, so re-check if we've stopped before
@@ -168,7 +183,7 @@
break
except ApiException as e:
if not self.stopping:
if not self.kube_client.maybe_reload_on_exception(exception=e):
if not kube_client.maybe_reload_on_exception(exception=e):
logger.exception(
"Unhandled API exception while watching Pod events - restarting watch!"
)
@@ -589,11 +604,18 @@ def run(self, task_config: KubernetesTaskConfig) -> Optional[str]:

def reconcile(self, task_config: KubernetesTaskConfig) -> None:
pod_name = task_config.pod_name
try:
pod = self.kube_client.get_pod(namespace=self.namespace, pod_name=pod_name)
except Exception:
logger.exception(f"Hit an exception attempting to fetch pod {pod_name}")
pod = None
pod = None
for kube_client in [self.kube_client] + self.old_kube_clients:
try:
pod = kube_client.get_pod(namespace=self.namespace, pod_name=pod_name)
except Exception:
logger.exception(
f"Hit an exception attempting to fetch pod {pod_name} from {kube_client.kubeconfig_path}"
)
else:
# kube_client.get_pod will return None with no exception if it sees a 404 from API
if pod:
break

if pod_name not in self.task_metadata:
self._initialize_existing_task(task_config)
@@ -640,9 +662,12 @@ def kill(self, task_id: str) -> bool:
This function will request that Kubernetes delete the named Pod and will return
True if the Pod termination request was successfully emitted or False otherwise.
"""
terminated = self.kube_client.terminate_pod(
namespace=self.namespace,
pod_name=task_id,
terminated = any(
kube_client.terminate_pod(
namespace=self.namespace,
pod_name=task_id,
)
for kube_client in [self.kube_client] + self.old_kube_clients
)
if terminated:
logger.info(
@@ -678,12 +703,14 @@ def stop(self) -> None:
logger.debug("Signaling Pod event Watch to stop streaming events...")
# make sure that we've stopped watching for events before calling join() - otherwise,
# join() will block until we hit the configured timeout (or forever with no timeout).
self.watch.stop()
for watch in self.watches:
watch.stop()
# timeout arbitrarily chosen - we mostly just want to make sure that we have a small
# grace period to flush the current event to the pending_events queue as well as
# any other clean-up - it's possible that after this join() the thread is still alive
# but in that case we can be reasonably sure that we're not dropping any data.
self.pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)
for pod_event_watch_thread in self.pod_event_watch_threads:
pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)

logger.debug("Waiting for all pending PodEvents to be processed...")
# once we've stopped updating the pending events queue, we then wait until we're done
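
The `__init__` and `stop()` changes above move from a single Watch/daemon-thread pair to one pair per client. The lifecycle pattern they rely on (stop every watch first, then join each daemon thread with a small timeout) can be sketched in isolation; `FakeWatch` below is a stand-in for `kubernetes.watch.Watch`, not the real client library:

```python
import queue
import threading
import time


# Stand-in for kubernetes.watch.Watch: streams events until stop() is called.
class FakeWatch:
    def __init__(self) -> None:
        self._stopping = False

    def stream(self):
        while not self._stopping:
            time.sleep(0.01)
            yield {"type": "MODIFIED"}

    def stop(self) -> None:
        self._stopping = True


def watch_loop(watch: FakeWatch, events: "queue.Queue[dict]") -> None:
    # Mirrors _pod_event_watch_loop: push every streamed event onto a shared queue.
    for event in watch.stream():
        events.put(event)


events: "queue.Queue[dict]" = queue.Queue()
watches = [FakeWatch() for _ in range(2)]  # one Watch per (current or old) cluster
threads = [
    threading.Thread(target=watch_loop, args=(w, events), daemon=True) for w in watches
]
for t in threads:
    t.start()

time.sleep(0.1)

# Mirrors stop(): signal every watch to stop streaming *before* joining the
# daemon threads, and join with a timeout so shutdown cannot block forever.
for w in watches:
    w.stop()
for t in threads:
    t.join(timeout=1.0)

print(f"collected {events.qsize()} events from {len(watches)} watches")
```

The daemon flag keeps the same trade-off noted in the inline comment: a Watch only checks its stop flag after receiving an event, so a quiet stream could otherwise block shutdown indefinitely.
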
71 changes: 71 additions & 0 deletions tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
@@ -52,6 +52,24 @@ def k8s_executor(mock_Thread):
executor.stop()


@pytest.fixture
def k8s_executor_with_old_clusters(mock_Thread):
with mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
autospec=True,
), mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_client", autospec=True
), mock.patch.dict(
os.environ, {"KUBECONFIG": "/this/doesnt/exist.conf"}
):
executor = KubernetesPodExecutor(
namespace="task_processing_tests",
old_kubeconfig_paths=["/this/also/doesnt/exist.conf"],
)
yield executor
executor.stop()


@pytest.fixture
def mock_task_configs():
test_task_names = ["job1.action1", "job1.action2", "job2.action1", "job3.action2"]
@@ -86,6 +104,18 @@ def k8s_executor_with_tasks(mock_Thread, mock_task_configs):
executor.stop()


def test_init_watch_setup(k8s_executor):
assert len(k8s_executor.watches) == len(k8s_executor.pod_event_watch_threads) == 1


def test_init_watch_setup_multicluster(k8s_executor_with_old_clusters):
assert (
len(k8s_executor_with_old_clusters.watches)
== len(k8s_executor_with_old_clusters.pod_event_watch_threads)
== 2
)


def test_run_updates_task_metadata(k8s_executor):
task_config = KubernetesTaskConfig(
name="name", uuid="uuid", image="fake_image", command="fake_command"
@@ -866,6 +896,47 @@ def test_reconcile_missing_pod(
assert tm.task_state == KubernetesTaskState.TASK_LOST


def test_reconcile_multicluster(
k8s_executor_with_old_clusters,
):
task_config = mock.Mock(spec=KubernetesTaskConfig)
task_config.pod_name = "pod--name.uuid"
task_config.name = "job-name"

k8s_executor_with_old_clusters.task_metadata = pmap(
{
task_config.pod_name: KubernetesTaskMetadata(
task_config=mock.Mock(spec=KubernetesTaskConfig),
task_state=KubernetesTaskState.TASK_UNKNOWN,
task_state_history=v(),
)
}
)

mock_old_kube_client = mock.Mock(autospec=True)
mock_found_pod = mock.Mock(spec=V1Pod)
mock_found_pod.metadata.name = task_config.pod_name
mock_found_pod.status.phase = "Running"
mock_found_pod.status.host_ip = "1.2.3.4"
mock_found_pod.spec.node_name = "kubenode"
mock_old_kube_client.get_pod.return_value = mock_found_pod
mock_old_kube_clients = [mock_old_kube_client]

with mock.patch.object(
k8s_executor_with_old_clusters, "kube_client", autospec=True
) as mock_kube_client, mock.patch.object(
k8s_executor_with_old_clusters, "old_kube_clients", mock_old_kube_clients
):
mock_kube_client.get_pod.return_value = None
k8s_executor_with_old_clusters.reconcile(task_config)

mock_old_kube_client.get_pod.assert_called()
assert k8s_executor_with_old_clusters.event_queue.qsize() == 1
assert len(k8s_executor_with_old_clusters.task_metadata) == 1
tm = k8s_executor_with_old_clusters.task_metadata["pod--name.uuid"]
assert tm.task_state == KubernetesTaskState.TASK_RUNNING


def test_reconcile_existing_pods(k8s_executor, mock_task_configs):
mock_pods = []
test_phases = ["Running", "Succeeded", "Failed", "Unknown"]
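
One behavioral detail of the `kill()` change in the executor diff is worth calling out: `any()` over a generator expression evaluates lazily, so `terminate_pod` is tried on the current cluster first and subsequent old clusters are only contacted until one reports success. A self-contained illustration of that short-circuit (`FakeClient` is a stand-in for the real `KubeClient`):

```python
# Stand-in for KubeClient, only to demonstrate the lazy fan-out used by kill().
class FakeClient:
    def __init__(self, name: str, will_succeed: bool) -> None:
        self.name = name
        self.will_succeed = will_succeed
        self.called = False

    def terminate_pod(self, namespace: str, pod_name: str) -> bool:
        self.called = True
        return self.will_succeed


new = FakeClient("current-cluster", will_succeed=False)
old_a = FakeClient("old-cluster-a", will_succeed=True)
old_b = FakeClient("old-cluster-b", will_succeed=True)

terminated = any(
    client.terminate_pod(namespace="task_processing_tests", pod_name="pod--name.uuid")
    for client in [new, old_a, old_b]
)

assert terminated
# any() short-circuits: once old_a accepts the termination, old_b is never contacted.
assert new.called and old_a.called and not old_b.called
```

During a migration a pod should only exist in one cluster at a time, so the short-circuit is presumably intentional; it simply means that a pathological duplicate pod in a later cluster would be left running.
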