Skip to content

Commit

Permalink
feat: Implement default lf tags for database (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
svdimchenko authored Sep 5, 2023
1 parent ded2747 commit b5214ba
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 18 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ A dbt profile can be configured to run against AWS Athena using the following co
| aws_profile_name | Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| num_retries | Number of times to retry a failing query | Optional | `3` |
| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` |

**Example profiles.yml entry:**
```yaml
Expand Down
4 changes: 4 additions & 0 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class AthenaCredentials(Credentials):
num_retries: Optional[int] = 5
s3_data_dir: Optional[str] = None
s3_data_naming: Optional[str] = "schema_table_unique"
# Unfortunately we can not just use dict, must by Dict because we'll get the following error:
# Credentials in profile "athena", target "athena" invalid: Unable to create schema for 'dict'
lf_tags_database: Optional[Dict[str, str]] = None

@property
def type(self) -> str:
Expand All @@ -83,6 +86,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
"s3_data_dir",
"s3_data_naming",
"debug_query_state",
"lf_tags_database",
)


Expand Down
13 changes: 13 additions & 0 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp"

@available
def add_lf_tags_to_database(self, relation: AthenaRelation) -> None:
conn = self.connections.get_thread_connection()
client = conn.handle
if lf_tags := conn.credentials.lf_tags_database:
config = LfTagsConfig(enabled=True, tags=lf_tags)
with boto3_client_lock:
lf_client = client.session.client("lakeformation", client.region_name, config=get_boto3_config())
manager = LfTagsManager(lf_client, relation, config)
manager.process_lf_tags_database()
else:
LOGGER.debug(f"Lakeformation is disabled for {relation}")

@available
def add_lf_tags(self, relation: AthenaRelation, lf_tags_config: Dict[str, Any]) -> None:
config = LfTagsConfig(**lf_tags_config)
Expand Down
31 changes: 20 additions & 11 deletions dbt/adapters/athena/lakeformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def __init__(self, lf_client: LakeFormationClient, relation: AthenaRelation, lf_
self.lf_tags = lf_tags_config.tags
self.lf_tags_columns = lf_tags_config.tags_columns

def process_lf_tags_database(self) -> None:
if self.lf_tags:
database_resource = {"Database": {"Name": self.database}}
response = self.lf_client.add_lf_tags_to_resource(
Resource=database_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()]
)
self._parse_and_log_lf_response(response, None, self.lf_tags)

def process_lf_tags(self) -> None:
table_resource = {"Table": {"DatabaseName": self.database, "Name": self.table}}
existing_lf_tags = self.lf_client.get_resource_lf_tags(Resource=table_resource)
Expand Down Expand Up @@ -65,7 +73,7 @@ def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTyp
response = self.lf_client.remove_lf_tags_from_resource(
Resource=resource, LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}]
)
logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value}, "remove"))
self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}, "remove")

def _apply_lf_tags_table(
self, table_resource: ResourceTypeDef, existing_lf_tags: GetResourceLFTagsResponseTypeDef
Expand All @@ -84,13 +92,13 @@ def _apply_lf_tags_table(
response = self.lf_client.remove_lf_tags_from_resource(
Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": v} for k, v in to_remove.items()]
)
logger.debug(self._parse_lf_response(response, None, self.lf_tags, "remove"))
self._parse_and_log_lf_response(response, None, self.lf_tags, "remove")

if self.lf_tags:
response = self.lf_client.add_lf_tags_to_resource(
Resource=table_resource, LFTags=[{"TagKey": k, "TagValues": [v]} for k, v in self.lf_tags.items()]
)
logger.debug(self._parse_lf_response(response, None, self.lf_tags))
self._parse_and_log_lf_response(response, None, self.lf_tags)

def _apply_lf_tags_columns(self) -> None:
if self.lf_tags_columns:
Expand All @@ -103,25 +111,26 @@ def _apply_lf_tags_columns(self) -> None:
Resource=resource,
LFTags=[{"TagKey": tag_key, "TagValues": [tag_value]}],
)
logger.debug(self._parse_lf_response(response, columns, {tag_key: tag_value}))
self._parse_and_log_lf_response(response, columns, {tag_key: tag_value})

def _parse_lf_response(
def _parse_and_log_lf_response(
self,
response: Union[AddLFTagsToResourceResponseTypeDef, RemoveLFTagsFromResourceResponseTypeDef],
columns: Optional[List[str]] = None,
lf_tags: Optional[Dict[str, str]] = None,
verb: str = "add",
) -> str:
failures = response.get("Failures", [])
) -> None:
table_appendix = f".{self.table}" if self.table else ""
columns_appendix = f" for columns {columns}" if columns else ""
if failures:
base_msg = f"Failed to {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix
resource_msg = self.database + table_appendix + columns_appendix
if failures := response.get("Failures", []):
base_msg = f"Failed to {verb} LF tags: {lf_tags} to " + resource_msg
for failure in failures:
tag = failure.get("LFTag", {}).get("TagKey")
error = failure.get("Error", {}).get("ErrorMessage")
logger.error(f"Failed to {verb} {tag} for {self.database}.{self.table}" + f" - {error}")
logger.error(f"Failed to {verb} {tag} for " + resource_msg + f" - {error}")
raise DbtRuntimeError(base_msg)
return f"Success: {verb} LF tags: {lf_tags} to {self.database}.{self.table}" + columns_appendix
logger.debug(f"Success: {verb} LF tags {lf_tags} to " + resource_msg)


class FilterConfig(BaseModel):
Expand Down
3 changes: 3 additions & 0 deletions dbt/include/athena/macros/adapters/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
{%- call statement('create_schema') -%}
create schema if not exists {{ relation.without_identifier().render_hive() }}
{% endcall %}

{{ adapter.add_lf_tags_to_database(relation) }}

{% endmacro %}


Expand Down
28 changes: 21 additions & 7 deletions tests/unit/test_lakeformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# get_resource_lf_tags
class TestLfTagsManager:
@pytest.mark.parametrize(
"response,columns,lf_tags,verb,expected",
"response,identifier,columns,lf_tags,verb,expected",
[
pytest.param(
{
Expand All @@ -22,6 +22,7 @@ class TestLfTagsManager:
}
]
},
"tbl_name",
["column1", "column2"],
{"tag_key": "tag_value"},
"add",
Expand All @@ -31,32 +32,45 @@ class TestLfTagsManager:
),
pytest.param(
{"Failures": []},
"tbl_name",
None,
{"tag_key": "tag_value"},
"add",
"Success: add LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
"Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
id="add lf_tag",
),
pytest.param(
{"Failures": []},
None,
None,
{"tag_key": "tag_value"},
"add",
"Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena",
id="add lf_tag_to_database",
),
pytest.param(
{"Failures": []},
"tbl_name",
None,
{"tag_key": "tag_value"},
"remove",
"Success: remove LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
"Success: remove LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name",
id="remove lf_tag",
),
pytest.param(
{"Failures": []},
"tbl_name",
["c1", "c2"],
{"tag_key": "tag_value"},
"add",
"Success: add LF tags: {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']",
"Success: add LF tags {'tag_key': 'tag_value'} to test_dbt_athena.tbl_name for columns ['c1', 'c2']",
id="lf_tag database table and columns",
),
],
)
def test__parse_lf_response(self, response, columns, lf_tags, verb, expected):
relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_name")
def test__parse_lf_response(self, dbt_debug_caplog, response, identifier, columns, lf_tags, verb, expected):
relation = AthenaRelation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=identifier)
lf_client = boto3.client("lakeformation", region_name=AWS_REGION)
manager = LfTagsManager(lf_client, relation, LfTagsConfig())
assert manager._parse_lf_response(response, columns, lf_tags, verb) == expected
manager._parse_and_log_lf_response(response, columns, lf_tags, verb)
assert expected in dbt_debug_caplog.getvalue()

0 comments on commit b5214ba

Please sign in to comment.