Skip to content

Commit

Permalink
Support for running dbt tasks in AWS EKS (#944)
Browse files Browse the repository at this point in the history
## Description
We are using MWAA in combination with EKS so that all our dags in
airflow are running in our EKS. We would like to use the same setup with
cosmos.

### What changes?
- New AwsEksOperator classes (inheriting from KubernetesOperators) -
Based on the original
[EksOperator](https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/operators/eks.py#L995)
- Tests
- Adjusted documentation

## Related Issue(s)

-

## Breaking Change?
No - only an additional feature

## Checklist

- [x] I have made corresponding changes to the documentation (if
required)
- [x] I have added tests that prove my fix is effective or that my
feature works

---------

Co-authored-by: Pankaj Koti <[email protected]>
  • Loading branch information
VolkerSchiewe and pankajkoti authored May 22, 2024
1 parent 007325a commit cb2a27a
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 2 deletions.
1 change: 1 addition & 0 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class ExecutionMode(Enum):
LOCAL = "local"
DOCKER = "docker"
KUBERNETES = "kubernetes"
AWS_EKS = "aws_eks"
VIRTUALENV = "virtualenv"
AZURE_CONTAINER_INSTANCE = "azure_container_instance"

Expand Down
1 change: 1 addition & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def validate_initial_user_config(
"""
if profile_config is None and execution_config.execution_mode not in (
ExecutionMode.KUBERNETES,
ExecutionMode.AWS_EKS,
ExecutionMode.DOCKER,
):
raise CosmosValueError(f"The profile_config is mandatory when using {execution_config.execution_mode}")
Expand Down
131 changes: 131 additions & 0 deletions cosmos/operators/aws_eks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from __future__ import annotations

from typing import Any, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.utils.context import Context

from cosmos.operators.kubernetes import (
DbtBuildKubernetesOperator,
DbtKubernetesBaseOperator,
DbtLSKubernetesOperator,
DbtRunKubernetesOperator,
DbtRunOperationKubernetesOperator,
DbtSeedKubernetesOperator,
DbtSnapshotKubernetesOperator,
DbtTestKubernetesOperator,
)

# Fallback Airflow connection ID used to reach AWS when the caller does not supply one.
DEFAULT_CONN_ID = "aws_default"
# Kubernetes namespace the pod is launched into when none is supplied.
DEFAULT_NAMESPACE = "default"


class DbtAwsEksBaseOperator(DbtKubernetesBaseOperator):
    """
    Base class for running a dbt command in a pod on an AWS EKS cluster.

    Modeled on Airflow's ``EksPodOperator``: a temporary kubeconfig for the
    target EKS cluster is generated at execution time via ``EksHook``, so the
    ``config_file`` parameter must NOT be passed by the caller.

    :param cluster_name: Name of the target AWS EKS cluster.
    :param pod_name: Name for the launched pod; forwarded to the Kubernetes
        operator as ``name``.
    :param namespace: Kubernetes namespace for the pod (default ``"default"``).
    :param aws_conn_id: Airflow connection used to authenticate with AWS
        (default ``"aws_default"``).
    :param region: AWS region of the cluster; ``None`` falls back to the
        connection's / environment's default region.
    """

    # Template the EKS-specific fields in addition to everything the
    # Kubernetes base operator already templates.
    template_fields: Sequence[str] = tuple(
        {
            "cluster_name",
            "in_cluster",
            "namespace",
            "pod_name",
            "aws_conn_id",
            "region",
        }
        | set(DbtKubernetesBaseOperator.template_fields)
    )

    def __init__(
        self,
        cluster_name: str,
        pod_name: str | None = None,
        namespace: str | None = DEFAULT_NAMESPACE,
        aws_conn_id: str = DEFAULT_CONN_ID,
        region: str | None = None,
        **kwargs: Any,
    ) -> None:
        self.cluster_name = cluster_name
        self.pod_name = pod_name
        self.namespace = namespace
        self.aws_conn_id = aws_conn_id
        self.region = region
        super().__init__(
            name=self.pod_name,
            namespace=self.namespace,
            **kwargs,
        )
        # There is no need to manage the kube_config file, as it will be generated automatically.
        # All Kubernetes parameters (except config_file) are also valid for the EksPodOperator.
        # NOTE: self.config_file is populated by the parent __init__ (from kwargs),
        # so this guard must run only after the super().__init__() call above.
        if self.config_file:
            raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.")

    def execute(self, context: Context) -> Any | None:  # type: ignore
        """Generate a temporary kubeconfig for the EKS cluster, then run the pod."""
        eks_hook = EksHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region,
        )
        # generate_config_file yields the path of a temporary kubeconfig file;
        # binding it to self.config_file lets the parent execute() pick it up,
        # and the file is cleaned up when the context manager exits.
        with eks_hook.generate_config_file(
            eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
        ) as self.config_file:
            return super().execute(context)


class DbtBuildAwsEksOperator(DbtAwsEksBaseOperator, DbtBuildKubernetesOperator):
    """Runs a ``dbt build`` command in a pod on an AWS EKS cluster."""

    # Render the templated fields of both parent operators.
    template_fields: Sequence[str] = (
        *DbtAwsEksBaseOperator.template_fields,
        *DbtBuildKubernetesOperator.template_fields,
    )


class DbtLSAwsEksOperator(DbtAwsEksBaseOperator, DbtLSKubernetesOperator):
    """
    Executes a dbt core ls command.
    """

    # Consistency fix: every sibling operator in this module merges the
    # template_fields of both parents; this one previously relied on MRO
    # alone and would miss any fields declared only on the ls operator.
    template_fields: Sequence[str] = (
        DbtAwsEksBaseOperator.template_fields + DbtLSKubernetesOperator.template_fields  # type: ignore[operator]
    )


class DbtSeedAwsEksOperator(DbtAwsEksBaseOperator, DbtSeedKubernetesOperator):
    """Runs a ``dbt seed`` command in a pod on an AWS EKS cluster."""

    # Render the templated fields of both parent operators.
    template_fields: Sequence[str] = (
        *DbtAwsEksBaseOperator.template_fields,
        *DbtSeedKubernetesOperator.template_fields,
    )


class DbtSnapshotAwsEksOperator(DbtAwsEksBaseOperator, DbtSnapshotKubernetesOperator):
    """
    Executes a dbt core snapshot command.
    """

    # Consistency fix: every sibling operator in this module merges the
    # template_fields of both parents; this one previously relied on MRO
    # alone and would miss any fields declared only on the snapshot operator.
    template_fields: Sequence[str] = (
        DbtAwsEksBaseOperator.template_fields + DbtSnapshotKubernetesOperator.template_fields  # type: ignore[operator]
    )


class DbtRunAwsEksOperator(DbtAwsEksBaseOperator, DbtRunKubernetesOperator):
    """Runs a ``dbt run`` command in a pod on an AWS EKS cluster."""

    # Render the templated fields of both parent operators.
    template_fields: Sequence[str] = (
        *DbtAwsEksBaseOperator.template_fields,
        *DbtRunKubernetesOperator.template_fields,
    )


class DbtTestAwsEksOperator(DbtAwsEksBaseOperator, DbtTestKubernetesOperator):
    """Runs a ``dbt test`` command in a pod on an AWS EKS cluster."""

    # Render the templated fields of both parent operators.
    template_fields: Sequence[str] = (
        *DbtAwsEksBaseOperator.template_fields,
        *DbtTestKubernetesOperator.template_fields,
    )


class DbtRunOperationAwsEksOperator(DbtAwsEksBaseOperator, DbtRunOperationKubernetesOperator):
    """Runs a ``dbt run-operation`` command in a pod on an AWS EKS cluster."""

    # Render the templated fields of both parent operators.
    template_fields: Sequence[str] = (
        *DbtAwsEksBaseOperator.template_fields,
        *DbtRunOperationKubernetesOperator.template_fields,
    )
39 changes: 38 additions & 1 deletion docs/getting_started/execution-modes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Cosmos can run ``dbt`` commands using six different approaches, called ``execut
2. **virtualenv**: Run ``dbt`` commands from Python virtual environments managed by Cosmos
3. **docker**: Run ``dbt`` commands from Docker containers managed by Cosmos (requires a pre-existing Docker image)
4. **kubernetes**: Run ``dbt`` commands from Kubernetes Pods managed by Cosmos (requires a pre-existing Docker image)
5. **azure_container_instance**: Run ``dbt`` commands from Azure Container Instances managed by Cosmos (requires a pre-existing Docker image)
5. **aws_eks**: Run ``dbt`` commands from AWS EKS Pods managed by Cosmos (requires a pre-existing Docker image)
6. **azure_container_instance**: Run ``dbt`` commands from Azure Container Instances managed by Cosmos (requires a pre-existing Docker image)

The choice of the ``execution mode`` can vary based on each user's needs and concerns. For more details, check each execution mode described below.

Expand Down Expand Up @@ -38,6 +39,10 @@ The choice of the ``execution mode`` can vary based on each user's needs and con
- Slow
- High
- No
* - AWS_EKS
- Slow
- High
- No
* - Azure Container Instance
- Slow
- High
Expand Down Expand Up @@ -159,6 +164,38 @@ Example DAG:
"secrets": [postgres_password_secret],
},
)

AWS_EKS
-------

The ``aws_eks`` approach is very similar to the ``kubernetes`` approach, but it is specifically designed to run on AWS EKS clusters.
It uses the `EKSPodOperator <https://airflow.apache.org/docs/apache-airflow-providers-amazon/8.19.0/operators/eks.html#perform-a-task-on-an-amazon-eks-cluster>`_
to run the dbt commands. You need to provide the ``cluster_name`` in your operator_args to connect to the AWS EKS cluster.


Example DAG:

.. code-block:: python

postgres_password_secret = Secret(
deploy_type="env",
deploy_target="POSTGRES_PASSWORD",
secret="postgres-secrets",
key="password",
)
docker_cosmos_dag = DbtDag(
# ...
execution_config=ExecutionConfig(
execution_mode=ExecutionMode.AWS_EKS,
),
operator_args={
"image": "dbt-jaffle-shop:1.0.0",
"cluster_name": CLUSTER_NAME,
"get_logs": True,
"is_delete_operator_pod": False,
"secrets": [postgres_password_secret],
},
)
Azure Container Instance
------------------------
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ docker = [
kubernetes = [
"apache-airflow-providers-cncf-kubernetes>=5.1.1",
]
aws_eks = [
"apache-airflow-providers-amazon>=8.0.0,<8.20.0", # https://github.com/apache/airflow/issues/39103
]
azure-container-instance = [
"apache-airflow-providers-microsoft-azure>=8.4.0",
]
Expand Down Expand Up @@ -120,6 +123,7 @@ dependencies = [
"astronomer-cosmos[tests]",
"apache-airflow-providers-postgres",
"apache-airflow-providers-cncf-kubernetes>=5.1.1",
"apache-airflow-providers-amazon>=3.0.0,<8.20.0", # https://github.com/apache/airflow/issues/39103
"apache-airflow-providers-docker>=3.5.0",
"apache-airflow-providers-microsoft-azure",
"types-PyYAML",
Expand All @@ -137,7 +141,7 @@ airflow = ["2.4", "2.5", "2.6", "2.7", "2.8", "2.9"]

[tool.hatch.envs.tests.overrides]
matrix.airflow.dependencies = [
{ value = "typing_extensions<4.6", if = ["2.6"] },
{ value = "typing_extensions<4.6", if = ["2.6"] }
]

[tool.hatch.envs.tests.scripts]
Expand Down
97 changes: 97 additions & 0 deletions tests/operators/test_aws_eks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from unittest.mock import MagicMock, patch

import pytest
from airflow.exceptions import AirflowException

from cosmos.operators.aws_eks import (
DbtBuildAwsEksOperator,
DbtLSAwsEksOperator,
DbtRunAwsEksOperator,
DbtSeedAwsEksOperator,
DbtTestAwsEksOperator,
)


@pytest.fixture()
def mock_kubernetes_execute():
    """Patch KubernetesPodOperator.execute so no real pod is ever launched."""
    patcher = patch("cosmos.operators.kubernetes.KubernetesPodOperator.execute")
    mock_execute = patcher.start()
    try:
        yield mock_execute
    finally:
        patcher.stop()


# Constructor arguments shared by every operator instantiated in these tests.
base_kwargs = dict(
    conn_id="my_airflow_connection",
    cluster_name="my-cluster",
    task_id="my-task",
    image="my_image",
    project_dir="my/dir",
    vars={
        "start_time": "{{ data_interval_start.strftime('%Y%m%d%H%M%S') }}",
        "end_time": "{{ data_interval_end.strftime('%Y%m%d%H%M%S') }}",
    },
    no_version_check=True,
)


def test_dbt_kubernetes_build_command():
    """
    Since we know that the KubernetesOperator is tested, we can just test that
    each dbt subcommand is rendered correctly into the "arguments" parameter.
    """
    # YAML rendering of the templated vars from base_kwargs (sorted by key).
    rendered_vars = (
        "end_time: '{{ data_interval_end.strftime(''%Y%m%d%H%M%S'') }}'\n"
        "start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n"
    )
    operator_classes = {
        "ls": DbtLSAwsEksOperator,
        "run": DbtRunAwsEksOperator,
        "test": DbtTestAwsEksOperator,
        "build": DbtBuildAwsEksOperator,
        "seed": DbtSeedAwsEksOperator,
    }

    for command_name, operator_class in operator_classes.items():
        command_operator = operator_class(**base_kwargs)
        command_operator.build_kube_args(context=MagicMock(), cmd_flags=MagicMock())
        assert command_operator.arguments == [
            "dbt",
            command_name,
            "--vars",
            rendered_vars,
            "--no-version-check",
            "--project-dir",
            "my/dir",
        ]


# NOTE: @patch decorators inject mocks bottom-up, so the innermost decorator
# (EksHook.generate_config_file) maps to the FIRST parameter.
@patch("cosmos.operators.kubernetes.DbtKubernetesBaseOperator.build_kube_args")
@patch("cosmos.operators.aws_eks.EksHook.generate_config_file")
def test_dbt_kubernetes_operator_execute(mock_generate_config_file, mock_build_kube_args, mock_kubernetes_execute):
    """Tests that the execute method call results in both the build_kube_args method and the kubernetes execute method being called."""
    operator = DbtLSAwsEksOperator(
        conn_id="my_airflow_connection",
        cluster_name="my-cluster",
        task_id="my-task",
        image="my_image",
        project_dir="my/dir",
    )
    operator.execute(context={})
    # Assert that the build_kube_args method was called in the execution
    mock_build_kube_args.assert_called_once()

    # Assert that the generate_config_file method was called in the execution to create the kubeconfig for eks
    mock_generate_config_file.assert_called_once_with(eks_cluster_name="my-cluster", pod_namespace="default")

    # Assert that the kubernetes execute method was called in the execution
    mock_kubernetes_execute.assert_called_once()
    assert mock_kubernetes_execute.call_args.args[-1] == {}


def test_provided_config_file_fails():
    """Tests that the constructor fails if it is called with a config_file."""
    init_kwargs = dict(
        conn_id="my_airflow_connection",
        cluster_name="my-cluster",
        task_id="my-task",
        image="my_image",
        project_dir="my/dir",
        config_file="my/config",
    )
    with pytest.raises(AirflowException) as exc_info:
        DbtLSAwsEksOperator(**init_kwargs)
    assert "The config_file is not an allowed parameter for the EksPodOperator." in str(exc_info.value)

0 comments on commit cb2a27a

Please sign in to comment.