diff --git a/README.md b/README.md
index 15b5a6c9..7b75bada 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,7 @@ A dbt profile can be configured to run against AWS Athena using the following co
 | aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
 | work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
 | num_retries | Number of times to retry a failing query | Optional | `3` |
+| lf_tags_database | Default LF tags for a new database if it is created by dbt | Optional | `tag_key: tag_value` |
 
 **Example profiles.yml entry:**
 ```yaml
diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py
index 89363d92..b0976b94 100644
--- a/dbt/adapters/athena/connections.py
+++ b/dbt/adapters/athena/connections.py
@@ -59,6 +59,9 @@ class AthenaCredentials(Credentials):
     num_retries: Optional[int] = 5
     s3_data_dir: Optional[str] = None
     s3_data_naming: Optional[str] = "schema_table_unique"
+    # Unfortunately we cannot use a plain dict here; it must be Dict, otherwise we get the following error:
+    # Credentials in profile "athena", target "athena" invalid: Unable to create schema for 'dict'
+    lf_tags_database: Optional[Dict[str, str]] = None
 
     @property
     def type(self) -> str:
@@ -83,6 +86,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
             "s3_data_dir",
             "s3_data_naming",
             "debug_query_state",
+            "lf_tags_database",
         )
 
 
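As a usage illustration (not part of the patch), a minimal sketch of a `profiles.yml` target that sets the new `lf_tags_database` key; the bucket, region, schema, and tag names below are placeholders:

```yaml
athena:
  target: dev
  outputs:
    dev:
      type: athena
      s3_staging_dir: s3://my-athena-results/   # placeholder bucket
      region_name: eu-west-1                    # placeholder region
      database: awsdatacatalog
      schema: analytics
      work_group: primary
      # LF tags attached when dbt itself creates the database (see the create_schema macro change below)
      lf_tags_database:
        data_owner: analytics
        environment: dev
```

The tag keys and values must already exist in Lake Formation; the adapter only attaches them to the new database, it does not create them.
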
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index 48b56dab..75585cc2 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -94,6 +94,19 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
     def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
         return "timestamp"
 
+    @available
+    def add_lf_tags_to_database(self, relation: AthenaRelation) -> None:
+        conn = self.connections.get_thread_connection()
+        client = conn.handle
+        if lf_tags := conn.credentials.lf_tags_database:
+            config = LfTagsConfig(enabled=True, tags=lf_tags)
+            with boto3_client_lock:
+                lf_client = client.session.client("lakeformation", client.region_name, config=get_boto3_config())
+            manager = LfTagsManager(lf_client, relation, config)
+            manager.process_lf_tags_database()
+        else:
+            LOGGER.debug(f"Lakeformation is disabled for {relation}")
+
     @available
     def add_lf_tags(self, relation: AthenaRelation, lf_tags_config: Dict[str, Any]) -> None:
         config = LfTagsConfig(**lf_tags_config)
diff --git a/dbt/adapters/athena/lakeformation.py b/dbt/adapters/athena/lakeformation.py
index cd29e7e6..9fc8047b 100644
--- a/dbt/adapters/athena/lakeformation.py
+++ b/dbt/adapters/athena/lakeformation.py
@@ -34,6 +34,14 @@ def __init__(self, lf_client: LakeFormationClient, relation: AthenaRelation, lf_
         self.lf_tags = lf_tags_config.tags
         self.lf_tags_columns = lf_tags_config.tags_columns
 
+    def process_lf_tags_database(self) -> None:
+        if self.lf_tags:
+            database_resource = {"Database": {"Name": self.database}}
+            response = self.lf_client.add_lf_tags_to_resource(
+                Resource=database_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()]
+            )
+            self._parse_and_log_lf_response(response, None, self.lf_tags)
+
     def process_lf_tags(self) -> None:
         table_resource = {"Table": {"DatabaseName": self.database, "Name": self.table}}
         existing_lf_tags = self.lf_client.get_resource_lf_tags(Resource=table_resource)
@@ -65,7 +73,7 @@ def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTyp
                     response = self.lf_client.remove_lf_tags_from_resource(
                         Resource=resource, LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}]
                     )
-                    logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value}, "remove"))
+                    self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}, "remove")
 
     def _apply_lf_tags_table(
         self, table_resource: ResourceTypeDef, existing_lf_tags: GetResourceLFTagsResponseTypeDef
@@ -84,13 +92,13 @@ def _apply_lf_tags_table(
             response = self.lf_client.remove_lf_tags_from_resource(
                 Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": v} for k, v in to_remove.items()]
             )
-            logger.debug(self._parse_lf_response(response, None, self.lf_tags, "remove"))
+            self._parse_and_log_lf_response(response, None, self.lf_tags, "remove")
 
         if self.lf_tags:
             response = self.lf_client.add_lf_tags_to_resource(
                 Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()]
             )
-            logger.debug(self._parse_lf_response(response, None, self.lf_tags))
+            self._parse_and_log_lf_response(response, None, self.lf_tags)
 
     def _apply_lf_tags_columns(self) -> None:
         if self.lf_tags_columns:
@@ -103,25 +111,26 @@ def _apply_lf_tags_columns(self) -> None:
                         Resource=resource,
                         LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}],
                     )
-                    logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value}))
+                    self._parse_and_log_lf_response(response, columns, {tag_key: tag_value})
 
-    def _parse_lf_response(
+    def _parse_and_log_lf_response(
         self,
         response: Union[AddLFTagsToResourceResponseTypeDef, RemoveLFTagsFromResourceResponseTypeDef],
         columns: Optional[List[str]] = None,
         lf_tags: Optional[Dict[str, str]] = None,
         verb: str = "add",
-    ) -> str:
-        failures = response.get("Failures", [])
+    ) -> None:
+        table_appendix = f".{self.table}" if self.table else ""
         columns_appendix = f" for columns {columns}" if columns else ""
-        if failures:
-            base_msg = f"Failed to {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix
+        resource_msg = self.database + table_appendix + columns_appendix
+        if failures := response.get("Failures", []):
+            base_msg = f"Failed to {verb} LF tags: {lf_tags} to " + resource_msg
             for failure in failures:
                 tag = failure.get("LFTag", {}).get("TagKey")
                 error = failure.get("Error", {}).get("ErrorMessage")
-                logger.error(f"Failed to {verb} {tag} for {self.database}.{self.table}" + f" - {error}")
+                logger.error(f"Failed to {verb} {tag} for " + resource_msg + f" - {error}")
             raise DbtRuntimeError(base_msg)
-        return f"Success: {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix
+        logger.debug(f"Success: {verb} LF tags {lf_tags} to " + resource_msg)
 
 
 class FilterConfig(BaseModel):
diff --git a/dbt/include/athena/macros/adapters/schema.sql b/dbt/include/athena/macros/adapters/schema.sql
index 777a3690..2750a7f9 100644
--- a/dbt/include/athena/macros/adapters/schema.sql
+++ b/dbt/include/athena/macros/adapters/schema.sql
@@ -2,6 +2,9 @@
   {%- call statement('create_schema') -%}
     create schema if not exists {{ relation.without_identifier().render_hive() }}
   {% endcall %}
+
+  {{ adapter.add_lf_tags_to_database(relation) }}
+
 {% endmacro %}
 
 
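Outside the patch itself, a self-contained sketch of the Lake Formation request that `process_lf_tags_database` effectively issues; the region, database name, and tag values here are hypothetical:

```python
import boto3

# Hypothetical region, database, and tags -- in the adapter these come from the
# connection credentials (lf_tags_database) and the relation being created.
lf_client = boto3.client("lakeformation", region_name="eu-west-1")
tags = {"data_owner": "analytics", "environment": "dev"}

response = lf_client.add_lf_tags_to_resource(
    # Database-level resource: no table name, unlike the table/column paths above.
    Resource={"Database": {"Name": "analytics"}},
    LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in tags.items()],
)

# Mirrors _parse_and_log_lf_response: any entry in "Failures" is treated as an error.
for failure in response.get("Failures", []):
    print(failure["LFTag"]["TagKey"], failure["Error"]["ErrorMessage"])
```

Since the `create_schema` macro now calls `adapter.add_lf_tags_to_database(relation)` unconditionally, the credentials check inside `add_lf_tags_to_database` is what keeps this a no-op for profiles that do not set `lf_tags_database`.
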
diff --git a/tests/unit/test_lakeformation.py b/tests/unit/test_lakeformation.py
index d5025e2b..ab061c09 100644
--- a/tests/unit/test_lakeformation.py
+++ b/tests/unit/test_lakeformation.py
@@ -11,7 +11,7 @@
 # get_resource_lf_tags
 class TestLfTagsManager:
     @pytest.mark.parametrize(
-        "response,columns,lf_tags,verb,expected",
+        "response,identifier,columns,lf_tags,verb,expected",
         [
             pytest.param(
                 {
@@ -22,6 +22,7 @@ class TestLfTagsManager:
                         }
                     ]
                 },
+                "tbl_name",
                 ["column1", "column2"],
                 {"tag_key": "tag_value"},
                 "add",
@@ -31,32 +32,45 @@ class TestLfTagsManager:
             ),
             pytest.param(
                 {"Failures": []},
+                "tbl_name",
                 None,
                 {"tag_key": "tag_value"},
                 "add",
-                "Success: add LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
+                "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
                 id="add lf_tag",
             ),
             pytest.param(
                 {"Failures": []},
                 None,
+                None,
+                {"tag_key": "tag_value"},
+                "add",
+                "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena",
+                id="add lf_tag_to_database",
+            ),
+            pytest.param(
+                {"Failures": []},
+                "tbl_name",
+                None,
                 {"tag_key": "tag_value"},
                 "remove",
-                "Success: remove LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
+                "Success: remove LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
                 id="remove lf_tag",
             ),
             pytest.param(
                 {"Failures": []},
+                "tbl_name",
                 ["c1", "c2"],
                 {"tag_key": "tag_value"},
                 "add",
-                "Success: add LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']",
+                "Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']",
                 id="lf_tag database table and columns",
             ),
         ],
     )
-    def test__parse_lf_response(self, response, columns, lf_tags, verb, expected):
-        relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_name")
+    def test__parse_lf_response(self, dbt_debug_caplog, response, identifier, columns, lf_tags, verb, expected):
+        relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=identifier)
         lf_client = boto3.client("lakeformation", region_name=AWS_REGION)
         manager = LfTagsManager(lf_client, relation, LfTagsConfig())
-        assert manager._parse_lf_response(response, columns, lf_tags, verb) == expected
+        manager._parse_and_log_lf_response(response, columns, lf_tags, verb)
+        assert expected in dbt_debug_caplog.getvalue()