diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..d660f8c8 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,2 @@ +github: [dbt-athena] +open_collective: dbt-athena diff --git a/README.md b/README.md index 717a2875..14028342 100644 --- a/README.md +++ b/README.md @@ -259,17 +259,39 @@ athena: } ``` -> Notes: +* `lf_inherited_tags` (`default=none`) + * List of Lake Formation tag keys that are intended to be inherited from the database level and thus shouldn't be + removed during association of those defined in `lf_tags_config` + * i.e. The default behavior of `lf_tags_config` is to be exhaustive and first remove any pre-existing tags from + tables and columns before associating the ones currently defined for a given model + * This breaks tag inheritance as inherited tags appear on tables and columns like those associated directly + * This list sits outside of `lf_tags_config` so that it can be set at the project level -- for example: + +```yaml +models: + my_project: + example: + +lf_inherited_tags: [inherited-tag-1, inherited-tag-2] +``` + +> Notes: > > * `lf_tags` and `lf_tags_columns` configs support only attaching lf tags to corresponding resources. - > We recommend managing LF Tags permissions somewhere outside dbt. For example, you may use - > [terraform](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_permissions) - or - > [aws cdk](https://docs.aws.amazon.com/cdk/api/v1/docs/aws-lakeformation-readme.html) for such purpose. +> We recommend managing LF Tags permissions somewhere outside dbt. For example, you may use +> [terraform](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/lakeformation_permissions) or +> [aws cdk](https://docs.aws.amazon.com/cdk/api/v1/docs/aws-lakeformation-readme.html) for such purpose. > * `data_cell_filters` management can't be automated outside dbt because the filter can't be attached to the table - > which doesn't exist. 
Once you `enable` this config, dbt will set all filters and their permissions during every - > dbt run. Such approach keeps the actual state of row level security configuration actual after every dbt run and - > apply changes if they occur: drop, create, update filters and their permissions. +> which doesn't exist. Once you `enable` this config, dbt will set all filters and their permissions during every +> dbt run. Such approach keeps the actual state of row level security configuration actual after every dbt run and +> apply changes if they occur: drop, create, update filters and their permissions. +> * Any tags listed in `lf_inherited_tags` should be strictly inherited from the database level and never overridden at + the table and column level +> * Currently `dbt-athena` does not differentiate between an inherited tag association and an override of same it made +> previously +> * e.g. If an inherited tag is overridden by an `lf_tags_config` value in one DBT run, and that override is removed + prior to a subsequent run, the prior override will linger and no longer be encoded anywhere (in e.g. Terraform + where the inherited value is configured nor in the DBT project where the override previously existed but now is + gone) [create-table-as]: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties diff --git a/RELEASING.md b/RELEASING.md new file mode 100644 index 00000000..81e93309 --- /dev/null +++ b/RELEASING.md @@ -0,0 +1,11 @@ +# How to make a release + +* open a Pull Request with a manual bump of in `main/dbt/adapters/athena/__version__.py` +* create a new release from + * be sure to use the same version as in the `__version__.py` file + * be sure to start the release with `v` e.g. v1.6.3 + * tag with the same name of the release e.g. v1.6.3 + * be sure to clean up release notes grouping by semantic commit type, + e.g. 
all feat commits should be under the same section +* Once the new release is made be sure that the new package version is available in PyPI + in [dbt-athena-community](https://pypi.org/project/dbt-athena-community/) diff --git a/dbt/adapters/athena/__version__.py b/dbt/adapters/athena/__version__.py index a9851426..33a97d94 100644 --- a/dbt/adapters/athena/__version__.py +++ b/dbt/adapters/athena/__version__.py @@ -1 +1 @@ -version = "1.6.2" +version = "1.6.3" diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index 0054e450..e0c9b2dd 100644 --- a/dbt/adapters/athena/connections.py +++ b/dbt/adapters/athena/connections.py @@ -53,6 +53,7 @@ class AthenaCredentials(Credentials): aws_profile_name: Optional[str] = None aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None + aws_session_token: Optional[str] = None poll_interval: float = 1.0 debug_query_state: bool = False _ALIASES = {"catalog": "database"} @@ -84,6 +85,7 @@ def _connection_keys(self) -> Tuple[str, ...]: "aws_profile_name", "aws_access_key_id", "aws_secret_access_key", + "aws_session_token", "endpoint_url", "s3_data_dir", "s3_data_naming", diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 8868c751..d2cc0061 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -51,11 +51,14 @@ get_catalog_id, get_catalog_type, get_chunks, + is_valid_table_parameter_key, + stringify_table_parameter_value, ) from dbt.adapters.base import ConstraintSupport, available from dbt.adapters.base.impl import AdapterConfig from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter +from dbt.config.runtime import RuntimeConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import CompiledNode, ConstraintType from dbt.exceptions import DbtRuntimeError @@ -156,8 +159,10 @@ def add_lf_tags_to_database(self, relation: AthenaRelation) -> None: 
LOGGER.debug(f"Lakeformation is disabled for {relation}") @available - def add_lf_tags(self, relation: AthenaRelation, lf_tags_config: Dict[str, Any]) -> None: - config = LfTagsConfig(**lf_tags_config) + def add_lf_tags( + self, relation: AthenaRelation, lf_tags_config: Dict[str, Any], lf_inherited_tags: Optional[List[str]] + ) -> None: + config = LfTagsConfig(**(lf_tags_config | dict(inherited_tags=lf_inherited_tags))) if config.enabled: conn = self.connections.get_thread_connection() client = conn.handle @@ -564,7 +569,7 @@ def _get_one_catalog( MaxResults=50, # Limit supported by this operation ): for table in page["TableMetadataList"]: - if relations and table["Name"] in relations: + if relations and table["Name"].lower() in relations: catalog.extend( self._get_one_table_for_non_glue_catalog( table, schema, information_schema.path.database @@ -668,16 +673,28 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati CatalogId=src_catalog_id, DatabaseName=src_relation.schema, Name=src_relation.identifier ).get("Table") - src_table_partitions = glue_client.get_partitions( - CatalogId=src_catalog_id, DatabaseName=src_relation.schema, TableName=src_relation.identifier - ).get("Partitions") + src_table_get_partitions_paginator = glue_client.get_paginator("get_partitions") + src_table_partitions_result = src_table_get_partitions_paginator.paginate( + **{ + "CatalogId": src_catalog_id, + "DatabaseName": src_relation.schema, + "TableName": src_relation.identifier, + } + ) + src_table_partitions = src_table_partitions_result.build_full_result().get("Partitions") data_catalog = self._get_data_catalog(src_relation.database) target_catalog_id = get_catalog_id(data_catalog) - target_table_partitions = glue_client.get_partitions( - CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableName=target_relation.identifier - ).get("Partitions") + target_get_partitions_paginator = glue_client.get_paginator("get_partitions") + 
target_table_partitions_result = target_get_partitions_paginator.paginate( + **{ + "CatalogId": target_catalog_id, + "DatabaseName": target_relation.schema, + "TableName": target_relation.identifier, + } + ) + target_table_partitions = target_table_partitions_result.build_full_result().get("Partitions") target_table_version = { "Name": target_relation.identifier, @@ -814,11 +831,12 @@ def persist_docs_to_glue( glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) # By default, there is no need to update Glue Table - need_udpate_table = False + need_to_update_table = False # Get Table from Glue table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.name)["Table"] # Prepare new version of Glue Table picking up significant fields - updated_table = self._get_table_input(table) + table_input = self._get_table_input(table) + table_parameters = table_input["Parameters"] # Update table description if persist_relation_docs: # Prepare dbt description @@ -829,16 +847,40 @@ def persist_docs_to_glue( glue_table_comment = table["Parameters"].get("comment", "") # Update description if it's different if clean_table_description != glue_table_description or clean_table_description != glue_table_comment: - updated_table["Description"] = clean_table_description - updated_table_parameters: Dict[str, str] = dict(updated_table["Parameters"]) - updated_table_parameters["comment"] = clean_table_description - updated_table["Parameters"] = updated_table_parameters - need_udpate_table = True + table_input["Description"] = clean_table_description + table_parameters["comment"] = clean_table_description + need_to_update_table = True + + # Get dbt model meta if available + meta: Dict[str, Any] = model.get("config", {}).get("meta", {}) + # Add some of dbt model config fields as table meta + meta["unique_id"] = model.get("unique_id") + meta["materialized"] = model.get("config", {}).get("materialized") + # Get 
dbt runtime config to be able to get dbt project metadata + runtime_config: RuntimeConfig = self.config + # Add dbt project metadata to table meta + meta["dbt_project_name"] = runtime_config.project_name + meta["dbt_project_version"] = runtime_config.version + # Prepare meta values for table properties and check if update is required + for meta_key, meta_value_raw in meta.items(): + if is_valid_table_parameter_key(meta_key): + meta_value = stringify_table_parameter_value(meta_value_raw) + if meta_value is not None: + # Check that meta value is already attached to Glue table + current_meta_value: Optional[str] = table_parameters.get(meta_key) + if current_meta_value is None or current_meta_value != meta_value: + # Update Glue table parameter only if needed + table_parameters[meta_key] = meta_value + need_to_update_table = True + else: + LOGGER.warning(f"Meta value for key '{meta_key}' is not supported and will be ignored") + else: + LOGGER.warning(f"Meta key '{meta_key}' is not supported and will be ignored") # Update column comments if persist_column_docs: # Process every column - for col_obj in updated_table["StorageDescriptor"]["Columns"]: + for col_obj in table_input["StorageDescriptor"]["Columns"]: # Get column description from dbt col_name = col_obj["Name"] if col_name in model["columns"]: @@ -850,15 +892,16 @@ def persist_docs_to_glue( # Update column description if it's different if glue_col_comment != clean_col_comment: col_obj["Comment"] = clean_col_comment - need_udpate_table = True + need_to_update_table = True # Update Glue Table only if table/column description is modified. # It prevents redundant schema version creating after incremental runs. 
- if need_udpate_table: + if need_to_update_table: + table_input["Parameters"] = table_parameters glue_client.update_table( CatalogId=catalog_id, DatabaseName=relation.schema, - TableInput=updated_table, + TableInput=table_input, SkipArchive=skip_archive_table_version, ) diff --git a/dbt/adapters/athena/lakeformation.py b/dbt/adapters/athena/lakeformation.py index 9fc8047b..0804a688 100644 --- a/dbt/adapters/athena/lakeformation.py +++ b/dbt/adapters/athena/lakeformation.py @@ -1,13 +1,15 @@ """AWS Lakeformation permissions management helper utilities.""" -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Sequence, Set, Union from mypy_boto3_lakeformation import LakeFormationClient from mypy_boto3_lakeformation.type_defs import ( AddLFTagsToResourceResponseTypeDef, BatchPermissionsRequestEntryTypeDef, + ColumnLFTagTypeDef, DataCellsFilterTypeDef, GetResourceLFTagsResponseTypeDef, + LFTagPairTypeDef, RemoveLFTagsFromResourceResponseTypeDef, ResourceTypeDef, ) @@ -24,6 +26,7 @@ class LfTagsConfig(BaseModel): enabled: bool = False tags: Optional[Dict[str, str]] = None tags_columns: Optional[Dict[str, Dict[str, List[str]]]] = None + inherited_tags: List[str] = [] class LfTagsManager: @@ -33,6 +36,7 @@ def __init__(self, lf_client: LakeFormationClient, relation: AthenaRelation, lf_ self.table = relation.identifier self.lf_tags = lf_tags_config.tags self.lf_tags_columns = lf_tags_config.tags_columns + self.lf_inherited_tags = set(lf_tags_config.inherited_tags) def process_lf_tags_database(self) -> None: if self.lf_tags: @@ -49,21 +53,31 @@ def process_lf_tags(self) -> None: self._apply_lf_tags_table(table_resource, existing_lf_tags) self._apply_lf_tags_columns() + @staticmethod + def _column_tags_to_remove( + lf_tags_columns: List[ColumnLFTagTypeDef], lf_inherited_tags: Set[str] + ) -> Dict[str, Dict[str, List[str]]]: + to_remove = {} + + for column in lf_tags_columns: + non_inherited_tags = [tag for tag in column["LFTags"] if not 
tag["TagKey"] in lf_inherited_tags] + for tag in non_inherited_tags: + tag_key = tag["TagKey"] + tag_value = tag["TagValues"][0] + if tag_key not in to_remove: + to_remove[tag_key] = {tag_value: [column["Name"]]} + elif tag_value not in to_remove[tag_key]: + to_remove[tag_key][tag_value] = [column["Name"]] + else: + to_remove[tag_key][tag_value].append(column["Name"]) + + return to_remove + def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTypeDef) -> None: lf_tags_columns = existing_lf_tags.get("LFTagsOnColumns", []) logger.debug(f"COLUMNS: {lf_tags_columns}") if lf_tags_columns: - to_remove = {} - for column in lf_tags_columns: - for tag in column["LFTags"]: - tag_key = tag["TagKey"] - tag_value = tag["TagValues"][0] - if tag_key not in to_remove: - to_remove[tag_key] = {tag_value: [column["Name"]]} - elif tag_value not in to_remove[tag_key]: - to_remove[tag_key][tag_value] = [column["Name"]] - else: - to_remove[tag_key][tag_value].append(column["Name"]) + to_remove = LfTagsManager._column_tags_to_remove(lf_tags_columns, self.lf_inherited_tags) logger.debug(f"TO REMOVE: {to_remove}") for tag_key, tag_config in to_remove.items(): for tag_value, columns in tag_config.items(): @@ -75,6 +89,17 @@ def _remove_lf_tags_columns(self, existing_lf_tags: GetResourceLFTagsResponseTyp ) self._parse_and_log_lf_response(response, columns, {tag_key: tag_value}, "remove") + @staticmethod + def _table_tags_to_remove( + lf_tags_table: List[LFTagPairTypeDef], lf_tags: Optional[Dict[str, str]], lf_inherited_tags: Set[str] + ) -> Dict[str, Sequence[str]]: + return { + tag["TagKey"]: tag["TagValues"] + for tag in lf_tags_table + if tag["TagKey"] not in (lf_tags or {}) + if tag["TagKey"] not in lf_inherited_tags + } + def _apply_lf_tags_table( self, table_resource: ResourceTypeDef, existing_lf_tags: GetResourceLFTagsResponseTypeDef ) -> None: @@ -82,11 +107,8 @@ def _apply_lf_tags_table( logger.debug(f"EXISTING TABLE TAGS: {lf_tags_table}") 
logger.debug(f"CONFIG TAGS: {self.lf_tags}") - to_remove = { - tag["TagKey"]: tag["TagValues"] - for tag in lf_tags_table - if tag["TagKey"] not in self.lf_tags # type: ignore - } + to_remove = LfTagsManager._table_tags_to_remove(lf_tags_table, self.lf_tags, self.lf_inherited_tags) + logger.debug(f"TAGS TO REMOVE: {to_remove}") if to_remove: response = self.lf_client.remove_lf_tags_from_resource( diff --git a/dbt/adapters/athena/session.py b/dbt/adapters/athena/session.py index 60c86fba..cea35a16 100644 --- a/dbt/adapters/athena/session.py +++ b/dbt/adapters/athena/session.py @@ -7,6 +7,7 @@ def get_boto3_session(connection: Connection) -> boto3.session.Session: return boto3.session.Session( aws_access_key_id=connection.credentials.aws_access_key_id, aws_secret_access_key=connection.credentials.aws_secret_access_key, + aws_session_token=connection.credentials.aws_session_token, region_name=connection.credentials.region_name, profile_name=connection.credentials.aws_profile_name, ) diff --git a/dbt/adapters/athena/utils.py b/dbt/adapters/athena/utils.py index dcd74916..0922f6da 100644 --- a/dbt/adapters/athena/utils.py +++ b/dbt/adapters/athena/utils.py @@ -1,14 +1,40 @@ +import json +import re from enum import Enum -from typing import Generator, List, Optional, TypeVar +from typing import Any, Generator, List, Optional, TypeVar from mypy_boto3_athena.type_defs import DataCatalogTypeDef +from dbt.adapters.athena.constants import LOGGER + def clean_sql_comment(comment: str) -> str: split_and_strip = [line.strip() for line in comment.split("\n")] return " ".join(line for line in split_and_strip if line) +def stringify_table_parameter_value(value: Any) -> Optional[str]: + """Convert any variable to string for Glue Table property.""" + try: + if isinstance(value, (dict, list)): + value_str: str = json.dumps(value) + else: + value_str = str(value) + return value_str[:512000] + except (TypeError, ValueError) as e: + # Handle non-stringifiable objects and non-serializable 
objects + LOGGER.warning(f"Non-stringifiable object. Error: {str(e)}") + return None + + +def is_valid_table_parameter_key(key: str) -> bool: + """Check if key is valid for Glue Table property according to official documentation.""" + # Simplified version of key pattern which works with re + # Original pattern can be found here https://docs.aws.amazon.com/glue/latest/webapi/API_Table.html + key_pattern: str = r"^[\u0020-\uD7FF\uE000-\uFFFD\t]*$" + return len(key) <= 255 and bool(re.match(key_pattern, key)) + + def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]: return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None diff --git a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql index 76c2ed73..e9de7a80 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql @@ -70,7 +70,7 @@ {%- set single_partition = [] -%} {%- for col in row -%} {%- set column_type = adapter.convert_type(table, loop.index0) -%} - {%- if column_type == 'integer' -%} + {%- if column_type == 'integer' or column_type is none -%} {%- set value = col|string -%} {%- elif column_type == 'string' -%} {%- set value = "'" + col + "'" -%} diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql index a5ed6812..11350ee7 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -6,6 +6,7 @@ {% set on_schema_change = incremental_validate_on_schema_change(config.get('on_schema_change'), default='ignore') %} {% set lf_tags_config = config.get('lf_tags_config') %} + {% set lf_inherited_tags = 
config.get('lf_inherited_tags') %} {% set lf_grants = config.get('lf_grants') %} {% set partitioned_by = config.get('partitioned_by') %} {% set target_relation = this.incorporate(type='table') %} @@ -106,7 +107,7 @@ {{ run_hooks(post_hooks, inside_transaction=False) }} {% if lf_tags_config is not none %} - {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {{ adapter.add_lf_tags(target_relation, lf_tags_config, lf_inherited_tags) }} {% endif %} {% if lf_grants is not none %} diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql index 989bf63b..7df54cec 100644 --- a/dbt/include/athena/macros/materializations/models/table/table.sql +++ b/dbt/include/athena/macros/materializations/models/table/table.sql @@ -3,6 +3,7 @@ {%- set identifier = model['alias'] -%} {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_inherited_tags = config.get('lf_inherited_tags') -%} {%- set lf_grants = config.get('lf_grants') -%} {%- set table_type = config.get('table_type', default='hive') | lower -%} @@ -111,7 +112,7 @@ {{ run_hooks(post_hooks) }} {% if lf_tags_config is not none %} - {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {{ adapter.add_lf_tags(target_relation, lf_tags_config, lf_inherited_tags) }} {% endif %} {% if lf_grants is not none %} diff --git a/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql b/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql index ae787a81..7e4c8b2a 100644 --- a/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql +++ b/dbt/include/athena/macros/materializations/models/view/create_or_replace_view.sql @@ -2,6 +2,7 @@ {%- set identifier = model['alias'] -%} {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_inherited_tags = config.get('lf_inherited_tags') -%} {%- set lf_grants = config.get('lf_grants') -%} 
{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} @@ -32,7 +33,7 @@ {%- endcall %} {% if lf_tags_config is not none %} - {{ adapter.add_lf_tags(target_relation, lf_tags_config) }} + {{ adapter.add_lf_tags(target_relation, lf_tags_config, lf_inherited_tags) }} {% endif %} {% if lf_grants is not none %} diff --git a/dbt/include/athena/macros/materializations/seeds/helpers.sql b/dbt/include/athena/macros/materializations/seeds/helpers.sql index a789498f..6f1daaba 100644 --- a/dbt/include/athena/macros/materializations/seeds/helpers.sql +++ b/dbt/include/athena/macros/materializations/seeds/helpers.sql @@ -91,6 +91,7 @@ {%- set identifier = model['alias'] -%} {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_inherited_tags = config.get('lf_inherited_tags') -%} {%- set lf_grants = config.get('lf_grants') -%} {%- set column_override = config.get('column_types', {}) -%} @@ -179,7 +180,7 @@ {% do adapter.delete_from_s3(tmp_s3_location) %} {% if lf_tags_config is not none %} - {{ adapter.add_lf_tags(relation, lf_tags_config) }} + {{ adapter.add_lf_tags(relation, lf_tags_config, lf_inherited_tags) }} {% endif %} {% if lf_grants is not none %} diff --git a/dbt/include/athena/macros/materializations/snapshots/snapshot.sql b/dbt/include/athena/macros/materializations/snapshots/snapshot.sql index 08d41c57..0ad26448 100644 --- a/dbt/include/athena/macros/materializations/snapshots/snapshot.sql +++ b/dbt/include/athena/macros/materializations/snapshots/snapshot.sql @@ -128,6 +128,7 @@ {%- set table_type = config.get('table_type', 'hive') -%} {%- set lf_tags_config = config.get('lf_tags_config') -%} + {%- set lf_inherited_tags = config.get('lf_inherited_tags') -%} {%- set lf_grants = config.get('lf_grants') -%} {{ log('Checking if target table exists') }} @@ -230,7 +231,7 @@ {{ run_hooks(post_hooks, inside_transaction=False) }} {% if lf_tags_config is not none %} - {{ adapter.add_lf_tags(target_relation, 
lf_tags_config) }} + {{ adapter.add_lf_tags(target_relation, lf_tags_config, lf_inherited_tags) }} {% endif %} {% if lf_grants is not none %} diff --git a/dev-requirements.txt b/dev-requirements.txt index 1e481fff..3cd122b4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,12 +1,12 @@ autoflake~=1.7 black~=23.9 boto3-stubs[s3]~=1.28 -dbt-tests-adapter~=1.6.5 +dbt-tests-adapter~=1.6.6 flake8~=6.1 Flake8-pyproject~=1.2 isort~=5.11 -moto~=4.2.5 -pre-commit~=3.4 +moto~=4.2.6 +pre-commit~=3.5 pyparsing~=3.1.1 pytest~=7.4 pytest-cov~=4.1 diff --git a/tests/unit/test_lakeformation.py b/tests/unit/test_lakeformation.py index ab061c09..3a05030c 100644 --- a/tests/unit/test_lakeformation.py +++ b/tests/unit/test_lakeformation.py @@ -2,6 +2,7 @@ import pytest from tests.unit.constants import AWS_REGION, DATA_CATALOG_NAME, DATABASE_NAME +import dbt.adapters.athena.lakeformation as lakeformation from dbt.adapters.athena.lakeformation import LfTagsConfig, LfTagsManager from dbt.adapters.athena.relation import AthenaRelation @@ -74,3 +75,70 @@ def test__parse_lf_response(self, dbt_debug_caplog, response, identifier, column manager = LfTagsManager(lf_client, relation, LfTagsConfig()) manager._parse_and_log_lf_response(response, columns, lf_tags, verb) assert expected in dbt_debug_caplog.getvalue() + + @pytest.mark.parametrize( + "lf_tags_columns,lf_inherited_tags,expected", + [ + pytest.param( + [{"Name": "my_column", "LFTags": [{"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}]}], + {"inherited"}, + {}, + id="retains-inherited-tag", + ), + pytest.param( + [{"Name": "my_column", "LFTags": [{"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}]}], + {}, + {"not-inherited": {"oh-no-im-not": ["my_column"]}}, + id="removes-non-inherited-tag", + ), + pytest.param( + [ + { + "Name": "my_column", + "LFTags": [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + } + ], + {"inherited"}, + 
{"not-inherited": {"oh-no-im-not": ["my_column"]}}, + id="removes-non-inherited-tag-among-inherited", + ), + pytest.param([], {}, {}, id="handles-empty"), + ], + ) + def test__column_tags_to_remove(self, lf_tags_columns, lf_inherited_tags, expected): + assert lakeformation.LfTagsManager._column_tags_to_remove(lf_tags_columns, lf_inherited_tags) == expected + + @pytest.mark.parametrize( + "lf_tags_table,lf_tags,lf_inherited_tags,expected", + [ + pytest.param( + [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + {"not-inherited": "some-preexisting-value"}, + {"inherited"}, + {}, + id="retains-being-set-and-inherited", + ), + pytest.param( + [ + {"TagKey": "not-inherited", "TagValues": ["oh-no-im-not"]}, + {"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}, + ], + {}, + {"inherited"}, + {"not-inherited": ["oh-no-im-not"]}, + id="removes-preexisting-not-being-set", + ), + pytest.param( + [{"TagKey": "inherited", "TagValues": ["oh-yes-i-am"]}], {}, {"inherited"}, {}, id="retains-inherited" + ), + pytest.param([], None, {}, {}, id="handles-empty"), + ], + ) + def test__table_tags_to_remove(self, lf_tags_table, lf_tags, lf_inherited_tags, expected): + assert lakeformation.LfTagsManager._table_tags_to_remove(lf_tags_table, lf_tags, lf_inherited_tags) == expected diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 23851360..543558e4 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,9 @@ -from dbt.adapters.athena.utils import clean_sql_comment, get_chunks +from dbt.adapters.athena.utils import ( + clean_sql_comment, + get_chunks, + is_valid_table_parameter_key, + stringify_table_parameter_value, +) def test_clean_comment(): @@ -14,6 +19,28 @@ def test_clean_comment(): ) +def test_stringify_table_parameter_value(): + class NonStringifiableObject: + def __str__(self): + raise ValueError("Non-stringifiable object") + + assert 
stringify_table_parameter_value(True) == "True" + assert stringify_table_parameter_value(123) == "123" + assert stringify_table_parameter_value("dbt-athena") == "dbt-athena" + assert stringify_table_parameter_value(["a", "b", 3]) == '["a", "b", 3]' + assert stringify_table_parameter_value({"a": 1, "b": "c"}) == '{"a": 1, "b": "c"}' + assert len(stringify_table_parameter_value("a" * 512001)) == 512000 + assert stringify_table_parameter_value(NonStringifiableObject()) is None + assert stringify_table_parameter_value([NonStringifiableObject()]) is None + + +def test_is_valid_table_parameter_key(): + assert is_valid_table_parameter_key("valid_key") is True + assert is_valid_table_parameter_key("Valid Key 123*!") is True + assert is_valid_table_parameter_key("invalid \n key") is False + assert is_valid_table_parameter_key("long_key" * 100) is False + + def test_get_chunks_empty(): assert len(list(get_chunks([], 5))) == 0