Skip to content

Commit

Permalink
Add support for Kubernetes on_warning_callback (#673)
Browse files Browse the repository at this point in the history
To make `on_warning_callback` work with pod operators, we need
to read the logs of the dbt test runs. This is done by ensuring the
pod is kept alive; in `on_success_callback` the log is then read and
analysed for warnings.

Afterwards, the pod is cleaned up based on the original settings from
the user.

If `on_warning_callback` is not set, everything stays the way it always
was.

This feature only works with `apache-airflow-providers-cncf-kubernetes >= 7.4.0`.
  • Loading branch information
david-mag authored Nov 22, 2023
1 parent a83911f commit 0b538a5
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 5 deletions.
100 changes: 96 additions & 4 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
from typing import Any, Callable, Sequence

import yaml
from airflow.utils.context import Context
from airflow.utils.context import Context, context_merge

from cosmos.log import get_logger
from cosmos.config import ProfileConfig
from cosmos.operators.base import DbtBaseOperator

from airflow.models import TaskInstance
from cosmos.dbt.parser.output import extract_log_issues

DBT_NO_TESTS_MSG = "Nothing to do"
DBT_WARN_MSG = "WARN"

logger = get_logger(__name__)

Expand All @@ -19,6 +24,7 @@
convert_env_vars,
)
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
except ImportError:
try:
# apache-airflow-providers-cncf-kubernetes < 7.4.0
Expand Down Expand Up @@ -158,10 +164,96 @@ class DbtTestKubernetesOperator(DbtKubernetesBaseOperator):
ui_color = "#8194E0"

def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None:
    """
    Initialize the dbt test operator for Kubernetes.

    When ``on_warning_callback`` is given, the pod is forced to stay alive
    after the run so its logs can be read and scanned for dbt warnings; the
    user's original cleanup preference is remembered and applied later by
    ``_handle_warnings`` / ``_cleanup_pod``.

    :param on_warning_callback: Optional callable invoked with the (merged)
        airflow context when the dbt test run reports warnings.
    :param kwargs: Forwarded to ``DbtKubernetesBaseOperator``.
    """
    if not on_warning_callback:
        # No warning handling requested: behave exactly as before.
        super().__init__(**kwargs)
    else:
        # Remember the user's original pod-cleanup preference, accepting
        # either the legacy ``is_delete_operator_pod`` flag or the newer
        # ``on_finish_action`` enum (provider >= 7.4.0).
        self.is_delete_operator_pod_original = kwargs.get("is_delete_operator_pod", None)
        if self.is_delete_operator_pod_original is not None:
            self.on_finish_action_original = (
                OnFinishAction.DELETE_POD if self.is_delete_operator_pod_original else OnFinishAction.KEEP_POD
            )
        else:
            self.on_finish_action_original = OnFinishAction(kwargs.get("on_finish_action", "delete_pod"))
            self.is_delete_operator_pod_original = self.on_finish_action_original == OnFinishAction.DELETE_POD

        # In order to read the pod logs, we need to keep the pod around.
        # Depending on the on_finish_action & is_delete_operator_pod settings,
        # we will clean up the pod later in the _handle_warnings method, which
        # is called in on_success_callback.
        kwargs["is_delete_operator_pod"] = False
        kwargs["on_finish_action"] = OnFinishAction.KEEP_POD

        # On success, check for warnings in the logs and clean up the pod.
        # Copy into a fresh list so a caller-supplied callback list is never
        # mutated in place (``+=`` on the original would extend it).
        success_callbacks = kwargs.get("on_success_callback", None) or []
        if not isinstance(success_callbacks, list):
            success_callbacks = [success_callbacks]
        kwargs["on_success_callback"] = success_callbacks + [self._handle_warnings]

        # On failure, only clean up the pod.
        failure_callbacks = kwargs.get("on_failure_callback", None) or []
        if not isinstance(failure_callbacks, list):
            failure_callbacks = [failure_callbacks]
        kwargs["on_failure_callback"] = failure_callbacks + [self._cleanup_pod]

        super().__init__(**kwargs)

    self.base_cmd = ["test"]
    self.on_warning_callback = on_warning_callback

def _handle_warnings(self, context: Context) -> None:
    """
    Handle warnings by extracting log issues, creating additional context,
    and calling the on_warning_callback with the updated context; finally
    clean up the pod according to the user's original settings.

    :param context: The original airflow context in which the build and run command was executed.
    """
    if not (
        isinstance(context["task_instance"], TaskInstance)
        and isinstance(context["task_instance"].task, DbtTestKubernetesOperator)
    ):
        return
    task = context["task_instance"].task
    logs = [
        log.decode("utf-8") for log in task.pod_manager.read_pod_logs(task.pod, "base") if log.decode("utf-8") != ""
    ]

    # Use short-circuiting ``and`` instead of ``all([...])``: a list literal
    # evaluates every element eagerly, so ``logs[-1]`` would raise IndexError
    # whenever the pod produced no (non-empty) log lines.
    should_trigger_callback = bool(
        logs
        and self.on_warning_callback
        and DBT_NO_TESTS_MSG not in logs[-1]
        and DBT_WARN_MSG in logs[-1]
    )

    if should_trigger_callback:
        # The dbt summary line looks like "Done. PASS=1 WARN=1 ERROR=0 ...".
        warnings = int(logs[-1].split(f"{DBT_WARN_MSG}=")[1].split()[0])
        if warnings > 0:
            test_names, test_results = extract_log_issues(logs)
            context_merge(context, test_names=test_names, test_results=test_results)
            self.on_warning_callback(context)

    self._cleanup_pod(context)

def _cleanup_pod(self, context: Context) -> None:
    """
    Handle the cleaning up of the pod after success or failure, if
    there is an on_warning_callback function defined.

    :param context: The original airflow context in which the build and run command was executed.
    """
    ti = context["task_instance"]
    if not isinstance(ti, TaskInstance):
        return
    task = ti.task
    if not isinstance(task, DbtTestKubernetesOperator):
        return
    if task.pod:
        # Restore the cleanup behaviour the user originally asked for
        # before tearing the pod down.
        task.on_finish_action = self.on_finish_action_original
        task.cleanup(pod=task.pod, remote_pod=task.remote_pod)


class DbtRunOperationKubernetesOperator(DbtKubernetesBaseOperator):
Expand Down
119 changes: 118 additions & 1 deletion tests/operators/test_kubernetes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from unittest.mock import MagicMock, patch

from airflow.utils.context import Context
import pytest
from pendulum import datetime

from cosmos.operators.kubernetes import (
Expand All @@ -12,6 +12,16 @@
DbtTestKubernetesOperator,
)

from airflow.utils.context import Context, context_merge
from airflow.models import TaskInstance

try:
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction

module_available = True
except ImportError:
module_available = False


def test_dbt_kubernetes_operator_add_global_flags() -> None:
dbt_kube_operator = DbtKubernetesBaseOperator(
Expand Down Expand Up @@ -103,6 +113,113 @@ def test_dbt_kubernetes_build_command():
]


@pytest.mark.parametrize(
    "additional_kwargs,expected_results",
    [
        ({"on_success_callback": None, "is_delete_operator_pod": True}, (1, 1, True, "delete_pod")),
        (
            {"on_success_callback": (lambda **kwargs: None), "is_delete_operator_pod": False},
            (2, 1, False, "keep_pod"),
        ),
        (
            {"on_success_callback": [(lambda **kwargs: None), (lambda **kwargs: None)], "is_delete_operator_pod": None},
            (3, 1, True, "delete_pod"),
        ),
        (
            {"on_failure_callback": None, "is_delete_operator_pod": True, "on_finish_action": "keep_pod"},
            (1, 1, True, "delete_pod"),
        ),
        (
            {
                "on_failure_callback": (lambda **kwargs: None),
                "is_delete_operator_pod": None,
                "on_finish_action": "delete_pod",
            },
            (1, 2, True, "delete_pod"),
        ),
        (
            {
                "on_failure_callback": [(lambda **kwargs: None), (lambda **kwargs: None)],
                "is_delete_operator_pod": None,
                "on_finish_action": "delete_succeeded_pod",
            },
            (1, 3, False, "delete_succeeded_pod"),
        ),
        ({"is_delete_operator_pod": None, "on_finish_action": "keep_pod"}, (1, 1, False, "keep_pod")),
        ({}, (1, 1, True, "delete_pod")),
    ],
)
@pytest.mark.skipif(
    not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
)
def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_results):
    """Verify callback registration and pod-cleanup bookkeeping for each kwarg combination."""
    test_operator = DbtTestKubernetesOperator(
        on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs
    )

    # (Removed a leftover debug ``print`` — pytest reports the parametrized
    # kwargs on failure already.)
    assert isinstance(test_operator.on_success_callback, list)
    assert isinstance(test_operator.on_failure_callback, list)
    # The operator must append its own hooks to whatever callbacks were given.
    assert test_operator._handle_warnings in test_operator.on_success_callback
    assert test_operator._cleanup_pod in test_operator.on_failure_callback
    assert len(test_operator.on_success_callback) == expected_results[0]
    assert len(test_operator.on_failure_callback) == expected_results[1]
    # The user's original cleanup preference must be preserved for later use.
    assert test_operator.is_delete_operator_pod_original == expected_results[2]
    assert test_operator.on_finish_action_original == OnFinishAction(expected_results[3])


class FakePodManager:
    """Test double for the kubernetes pod manager: serves a canned dbt test log."""

    def read_pod_logs(self, pod, container):
        # Guard that the operator asks for the expected pod and container.
        assert pod == "pod"
        assert container == "base"
        # Canned dbt output containing one WARN test and one PASS test; the
        # exact bytes matter because _handle_warnings parses the final
        # "Done. PASS=... WARN=..." summary line.
        log_string = """
19:48:25 Concurrency: 4 threads (target='target')
19:48:25
19:48:25 1 of 2 START test dbt_utils_accepted_range_table_col__12__0 ................... [RUN]
19:48:25 2 of 2 START test unique_table__uuid .......................................... [RUN]
19:48:27 1 of 2 WARN 252 dbt_utils_accepted_range_table_col__12__0 ..................... [WARN 117 in 1.83s]
19:48:27 2 of 2 PASS unique_table__uuid ................................................ [PASS in 1.85s]
19:48:27
19:48:27 Finished running 2 tests, 1 hook in 0 hours 0 minutes and 12.86 seconds (12.86s).
19:48:27
19:48:27 Completed with 1 warning:
19:48:27
19:48:27 Warning in test dbt_utils_accepted_range_table_col__12__0 (models/ads/ads.yaml)
19:48:27 Got 252 results, configured to warn if >0
19:48:27
19:48:27 compiled Code at target/compiled/model/models/table/table.yaml/dbt_utils_accepted_range_table_col__12__0.sql
19:48:27
19:48:27 Done. PASS=1 WARN=1 ERROR=0 SKIP=0 TOTAL=2
"""
        # Mimic the real API, which yields raw byte lines.
        return (log.encode("utf-8") for log in log_string.split("\n"))


@pytest.mark.skipif(
    not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available"
)
def test_dbt_test_kubernetes_operator_handle_warnings_and_cleanup_pod():
    """End-to-end check of _handle_warnings: warning extraction plus pod cleanup."""

    def on_warning_callback(context: Context):
        # Values parsed out of FakePodManager's canned log output.
        assert context["test_names"] == ["dbt_utils_accepted_range_table_col__12__0"]
        assert context["test_results"] == ["Got 252 results, configured to warn if >0"]

    def cleanup(pod: str, remote_pod: str):
        assert pod == remote_pod

    operator = DbtTestKubernetesOperator(
        is_delete_operator_pod=True, on_warning_callback=on_warning_callback, **base_kwargs
    )
    ti = TaskInstance(operator)
    # Wire the fake pod manager and a cleanup stub onto the task.
    ti.task.pod_manager = FakePodManager()
    ti.task.pod = ti.task.remote_pod = "pod"
    ti.task.cleanup = cleanup

    ctx = Context()
    context_merge(ctx, task_instance=ti)

    operator._handle_warnings(ctx)


@patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.hook")
def test_created_pod(test_hook):
test_hook.is_in_cluster = False
Expand Down

0 comments on commit 0b538a5

Please sign in to comment.