From 0e08feb8890bed3e2ee698bb438f01fa55e0e009 Mon Sep 17 00:00:00 2001 From: Benjamin Dornel Date: Fri, 6 Oct 2023 15:39:59 +0800 Subject: [PATCH] Add AWS Athena profile mapping --- cosmos/profiles/__init__.py | 2 + cosmos/profiles/athena/__init__.py | 5 + cosmos/profiles/athena/access_key.py | 59 +++++++ pyproject.toml | 4 + .../profiles/athena/test_athena_access_key.py | 154 ++++++++++++++++++ 5 files changed, 224 insertions(+) create mode 100644 cosmos/profiles/athena/__init__.py create mode 100644 cosmos/profiles/athena/access_key.py create mode 100644 tests/profiles/athena/test_athena_access_key.py diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index 752a84ff3..dae6e2c04 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -5,6 +5,7 @@ from typing import Any, Type +from .athena import AthenaAccessKeyProfileMapping from .base import BaseProfileMapping from .bigquery.service_account_file import GoogleCloudServiceAccountFileProfileMapping from .bigquery.service_account_keyfile_dict import GoogleCloudServiceAccountDictProfileMapping @@ -21,6 +22,7 @@ from .trino.ldap import TrinoLDAPProfileMapping profile_mappings: list[Type[BaseProfileMapping]] = [ + AthenaAccessKeyProfileMapping, GoogleCloudServiceAccountFileProfileMapping, GoogleCloudServiceAccountDictProfileMapping, GoogleCloudOauthProfileMapping, diff --git a/cosmos/profiles/athena/__init__.py b/cosmos/profiles/athena/__init__.py new file mode 100644 index 000000000..0cbb09a7c --- /dev/null +++ b/cosmos/profiles/athena/__init__.py @@ -0,0 +1,5 @@ +"Athena Airflow connection -> dbt profile mappings" + +from .access_key import AthenaAccessKeyProfileMapping + +__all__ = ["AthenaAccessKeyProfileMapping"] diff --git a/cosmos/profiles/athena/access_key.py b/cosmos/profiles/athena/access_key.py new file mode 100644 index 000000000..b79bb793a --- /dev/null +++ b/cosmos/profiles/athena/access_key.py @@ -0,0 +1,59 @@ +"Maps Airflow AWS connections to a dbt Athena profile 
using an access key id and secret access key." +from __future__ import annotations + +from typing import Any + +from ..base import BaseProfileMapping + + +class AthenaAccessKeyProfileMapping(BaseProfileMapping): + """ + Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key. + + https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup + https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html + """ + + airflow_connection_type: str = "aws" + dbt_profile_type: str = "athena" + is_community: bool = True + + required_fields = [ + "aws_access_key_id", + "aws_secret_access_key", + "database", + "region_name", + "s3_staging_dir", + "schema", + ] + secret_fields = [ + "aws_secret_access_key", + ] + airflow_param_mapping = { + "aws_access_key_id": "login", + "aws_secret_access_key": "password", + "aws_profile_name": "extra.aws_profile_name", + "database": "extra.database", + "debug_query_state": "extra.debug_query_state", + "lf_tags_database": "extra.lf_tags_database", + "num_retries": "extra.num_retries", + "poll_interval": "extra.poll_interval", + "region_name": "extra.region_name", + "s3_data_dir": "extra.s3_data_dir", + "s3_data_naming": "extra.s3_data_naming", + "s3_staging_dir": "extra.s3_staging_dir", + "schema": "extra.schema", + "seed_s3_upload_args": "extra.seed_s3_upload_args", + "work_group": "extra.work_group", + } + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile. The password is stored in an environment variable." 
+ profile = {
+ **self.mapped_params,
+ **self.profile_args,
+ # aws_secret_access_key should always get set as env var
+ "aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"),
+ }
+ return self.filter_null(profile)
diff --git a/pyproject.toml b/pyproject.toml
index 11e819774..d194779f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
[project.optional-dependencies]
dbt-all = [
+ "dbt-athena-community",
"dbt-bigquery",
"dbt-databricks",
"dbt-exasol",
@@ -54,6 +55,9 @@ dbt-all = [
"dbt-snowflake",
"dbt-spark",
]
+dbt-athena = [
+ "dbt-athena-community",
+]
dbt-bigquery = [
"dbt-bigquery",
]
diff --git a/tests/profiles/athena/test_athena_access_key.py b/tests/profiles/athena/test_athena_access_key.py
new file mode 100644
index 000000000..2063ef6ed
--- /dev/null
+++ b/tests/profiles/athena/test_athena_access_key.py
@@ -0,0 +1,154 @@
+"Tests for the Athena profile."
+
+import json
+from unittest.mock import patch
+
+import pytest
+from airflow.models.connection import Connection
+
+from cosmos.profiles import get_automatic_profile_mapping
+from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping
+
+
+@pytest.fixture()
+def mock_athena_conn(): # type: ignore
+ """
+ Sets the connection as an environment variable.
+ """
+ conn = Connection(
+ conn_id="my_athena_connection",
+ conn_type="aws",
+ login="my_aws_access_key_id",
+ password="my_aws_secret_key",
+ extra=json.dumps(
+ {
+ "database": "my_database",
+ "region_name": "my_region",
+ "s3_staging_dir": "s3://my_bucket/dbt/",
+ "schema": "my_schema",
+ }
+ ),
+ )
+
+ with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
+ yield conn
+
+
+def test_athena_connection_claiming() -> None:
+ """
+ Tests that the Athena profile mapping claims the correct connection type. 
+ """ + # should only claim when: + # - conn_type == aws + # and the following exist: + # - login + # - password + # - database + # - region_name + # - s3_staging_dir + # - schema + potential_values = { + "conn_type": "aws", + "login": "my_aws_access_key_id", + "password": "my_aws_secret_key", + "extra": json.dumps( + { + "database": "my_database", + "region_name": "my_region", + "s3_staging_dir": "s3://my_bucket/dbt/", + "schema": "my_schema", + } + ), + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + # should raise an InvalidMappingException + profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) + assert profile_mapping.can_claim_connection() + + +def test_athena_profile_mapping_selected( + mock_athena_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected for Athena. + """ + profile_mapping = get_automatic_profile_mapping( + mock_athena_conn.conn_id, + ) + assert isinstance(profile_mapping, AthenaAccessKeyProfileMapping) + + +def test_athena_profile_args( + mock_athena_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly for Athena. 
+ """ + profile_mapping = get_automatic_profile_mapping( + mock_athena_conn.conn_id, + ) + + assert profile_mapping.profile == { + "type": "athena", + "aws_access_key_id": mock_athena_conn.login, + "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", + "database": mock_athena_conn.extra_dejson.get("database"), + "region_name": mock_athena_conn.extra_dejson.get("region_name"), + "s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), + "schema": mock_athena_conn.extra_dejson.get("schema"), + } + + +def test_athena_profile_args_overrides( + mock_athena_conn: Connection, +) -> None: + """ + Tests that you can override the profile values for Athena. + """ + profile_mapping = get_automatic_profile_mapping( + mock_athena_conn.conn_id, + profile_args={"schema": "my_custom_schema", "database": "my_custom_db"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_custom_schema", + "database": "my_custom_db", + } + + assert profile_mapping.profile == { + "type": "athena", + "aws_access_key_id": mock_athena_conn.login, + "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", + "database": "my_custom_db", + "region_name": mock_athena_conn.extra_dejson.get("region_name"), + "s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), + "schema": "my_custom_schema", + } + + +def test_athena_profile_env_vars( + mock_athena_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly for Athena. + """ + profile_mapping = get_automatic_profile_mapping( + mock_athena_conn.conn_id, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password, + }