diff --git a/dbt-athena/src/dbt/adapters/athena/impl.py b/dbt-athena/src/dbt/adapters/athena/impl.py
index d230f471..5b94a4c9 100755
--- a/dbt-athena/src/dbt/adapters/athena/impl.py
+++ b/dbt-athena/src/dbt/adapters/athena/impl.py
@@ -1227,8 +1227,7 @@ def _generate_snapshot_migration_sql(self, relation: AthenaRelation, table_colum
             """
         )
         staging_sql = self.execute_macro(
-            "create_table_as",
-            kwargs=dict(temporary=True, relation=staging_relation, compiled_code=ctas),
+            "create_table_as", kwargs=dict(temporary=True, relation=staging_relation, compiled_code=ctas)
         )
 
         backup_relation = relation.incorporate(path={"identifier": relation.identifier + "__dbt_tmp_migration_backup"})
diff --git a/dbt-athena/tests/unit/test_adapter.py b/dbt-athena/tests/unit/test_adapter.py
index df171812..e3c3b923 100644
--- a/dbt-athena/tests/unit/test_adapter.py
+++ b/dbt-athena/tests/unit/test_adapter.py
@@ -135,10 +135,7 @@ def test_acquire_connection_exc(self, connection_cls, dbt_error_caplog):
         assert conn_res is None
         assert connection.state == ConnectionState.FAIL
         assert exc.value.__str__() == "foobar"
-        assert (
-            "Got an error when attempting to open a Athena connection due to foobar"
-            in dbt_error_caplog.getvalue()
-        )
+        assert "Got an error when attempting to open a Athena connection due to foobar" in dbt_error_caplog.getvalue()
 
     @pytest.mark.parametrize(
         (
@@ -152,24 +149,10 @@ def test_acquire_connection_exc(self, connection_cls, dbt_error_caplog):
         ),
         (
             pytest.param(
-                None,
-                "table",
-                None,
-                None,
-                None,
-                False,
-                "s3://my-bucket/test-dbt/tables/table",
-                id="table naming",
+                None, "table", None, None, None, False, "s3://my-bucket/test-dbt/tables/table", id="table naming"
             ),
             pytest.param(
-                None,
-                "unique",
-                None,
-                None,
-                None,
-                False,
-                "s3://my-bucket/test-dbt/tables/uuid",
-                id="unique naming",
+                None, "unique", None, None, None, False, "s3://my-bucket/test-dbt/tables/uuid", id="unique naming"
             ),
             pytest.param(
                 None,
@@ -273,12 +256,7 @@ def test_generate_s3_location(
             s3_path_table_part=s3_path_table_part,
         )
         assert expected == self.adapter.generate_s3_location(
-            relation,
-            s3_data_dir,
-            s3_data_naming,
-            s3_tmp_table_dir,
-            external_location,
-            is_temporary_table,
+            relation, s3_data_dir, s3_data_naming, s3_tmp_table_dir, external_location, is_temporary_table
         )
 
     @mock_aws
@@ -293,15 +271,10 @@ def test_get_table_location(self, dbt_debug_caplog, mock_aws_service):
             schema=DATABASE_NAME,
             identifier=table_name,
         )
-        assert (
-            self.adapter.get_glue_table_location(relation)
-            == "s3://test-dbt-athena/tables/test_table"
-        )
+        assert self.adapter.get_glue_table_location(relation) == "s3://test-dbt-athena/tables/test_table"
 
     @mock_aws
-    def test_get_table_location_raise_s3_location_exception(
-        self, dbt_debug_caplog, mock_aws_service
-    ):
+    def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, mock_aws_service):
         table_name = "test_table"
         self.adapter.acquire_connection("dummy")
         mock_aws_service.create_data_catalog()
@@ -327,10 +300,7 @@ def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service):
         mock_aws_service.create_database()
         mock_aws_service.create_view(view_name)
         relation = self.adapter.Relation.create(
-            database=DATA_CATALOG_NAME,
-            schema=DATABASE_NAME,
-            identifier=view_name,
-            type=RelationType.View,
+            database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier=view_name, type=RelationType.View
         )
         assert self.adapter.get_glue_table_location(relation) is None
 
@@ -346,9 +316,7 @@ def test_get_table_location_with_failure(self, dbt_debug_caplog, mock_aws_servic
             identifier=table_name,
         )
         assert self.adapter.get_glue_table_location(relation) is None
-        assert (
-            f"Table {relation.render()} does not exists - Ignoring" in dbt_debug_caplog.getvalue()
-        )
+        assert f"Table {relation.render()} does not exists - Ignoring" in dbt_debug_caplog.getvalue()
 
     @mock_aws
     def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service):
@@ -379,10 +347,7 @@ def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service)
         )
         s3 = boto3.client("s3", region_name=AWS_REGION)
         keys = [obj["Key"] for obj in s3.list_objects_v2(Bucket=BUCKET)["Contents"]]
-        assert set(keys) == {
-            "tables/table/dt=2022-01-03/data1.parquet",
-            "tables/table/dt=2022-01-03/data2.parquet",
-        }
+        assert set(keys) == {"tables/table/dt=2022-01-03/data1.parquet", "tables/table/dt=2022-01-03/data2.parquet"}
 
     @mock_aws
     def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_service):
@@ -397,8 +362,7 @@ def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_se
         result = self.adapter.clean_up_table(relation)
         assert result is None
         assert (
-            'Table "awsdatacatalog"."test_dbt_athena"."table" does not exists - Ignoring'
-            in dbt_debug_caplog.getvalue()
+            'Table "awsdatacatalog"."test_dbt_athena"."table" does not exists - Ignoring' in dbt_debug_caplog.getvalue()
         )
 
     @mock_aws
@@ -462,9 +426,7 @@ def test__get_one_catalog(self, mock_aws_service):
         mock_information_schema.database = "awsdatacatalog"
 
         self.adapter.acquire_connection("dummy")
-        actual = self.adapter._get_one_catalog(
-            mock_information_schema, {"foo", "quux"}, self.used_schemas
-        )
+        actual = self.adapter._get_one_catalog(mock_information_schema, {"foo", "quux"}, self.used_schemas)
 
         expected_column_names = (
             "table_database",
@@ -532,21 +494,15 @@ def test__get_one_catalog_by_relations(self, mock_aws_service):
             ("awsdatacatalog", "foo", "bar", "table", None, "dt", 2, "date", None),
         ]
 
-        actual = self.adapter._get_one_catalog_by_relations(
-            mock_information_schema, [rel_1], self.used_schemas
-        )
+        actual = self.adapter._get_one_catalog_by_relations(mock_information_schema, [rel_1], self.used_schemas)
         assert actual.column_names == expected_column_names
         assert actual.rows == expected_rows
 
     @mock_aws
     def test__get_one_catalog_shared_catalog(self, mock_aws_service):
-        mock_aws_service.create_data_catalog(
-            catalog_name=SHARED_DATA_CATALOG_NAME, catalog_id=SHARED_DATA_CATALOG_NAME
-        )
+        mock_aws_service.create_data_catalog(catalog_name=SHARED_DATA_CATALOG_NAME, catalog_id=SHARED_DATA_CATALOG_NAME)
         mock_aws_service.create_database("foo", catalog_id=SHARED_DATA_CATALOG_NAME)
-        mock_aws_service.create_table(
-            table_name="bar", database_name="foo", catalog_id=SHARED_DATA_CATALOG_NAME
-        )
+        mock_aws_service.create_table(table_name="bar", database_name="foo", catalog_id=SHARED_DATA_CATALOG_NAME)
         mock_information_schema = mock.MagicMock()
         mock_information_schema.database = SHARED_DATA_CATALOG_NAME
 
@@ -622,9 +578,7 @@ def mock_athena_list_table_metadata(self, operation_name, kwarg):
             return orig(self, operation_name, kwarg)
 
         self.adapter.acquire_connection("dummy")
-        with patch(
-            "botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata
-        ):
+        with patch("botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata):
             actual = self.adapter._get_one_catalog(
                 mock_information_schema,
                 {"foo"},
@@ -644,17 +598,7 @@ def mock_athena_list_table_metadata(self, operation_name, kwarg):
         )
         expected_rows = [
             (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None),
-            (
-                FEDERATED_QUERY_CATALOG_NAME,
-                "foo",
-                "bar",
-                "table",
-                None,
-                "country",
-                1,
-                "string",
-                None,
-            ),
+            (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None),
             (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None),
         ]
 
@@ -668,11 +612,7 @@ def test__get_data_catalog(self, mock_aws_service):
         mock_aws_service.create_data_catalog()
         self.adapter.acquire_connection("dummy")
         res = self.adapter._get_data_catalog(DATA_CATALOG_NAME)
-        assert {
-            "Name": "awsdatacatalog",
-            "Type": "GLUE",
-            "Parameters": {"catalog-id": DEFAULT_ACCOUNT_ID},
-        } == res
+        assert {"Name": "awsdatacatalog", "Type": "GLUE", "Parameters": {"catalog-id": DEFAULT_ACCOUNT_ID}} == res
 
     def _test_list_relations_without_caching(self, schema_relation):
         self.adapter.acquire_connection("dummy")
@@ -760,14 +700,8 @@ def test_list_relations_without_caching_with_non_glue_data_catalog(
     @pytest.mark.parametrize(
         "s3_path,expected",
         [
-            (
-                "s3://my-bucket/test-dbt/tables/schema/table",
-                ("my-bucket", "test-dbt/tables/schema/table/"),
-            ),
-            (
-                "s3://my-bucket/test-dbt/tables/schema/table/",
-                ("my-bucket", "test-dbt/tables/schema/table/"),
-            ),
+            ("s3://my-bucket/test-dbt/tables/schema/table", ("my-bucket", "test-dbt/tables/schema/table/")),
+            ("s3://my-bucket/test-dbt/tables/schema/table/", ("my-bucket", "test-dbt/tables/schema/table/")),
         ],
     )
     def test_parse_s3_path(self, s3_path, expected):
@@ -795,10 +729,7 @@ def test_swap_table_with_partitions(self, mock_aws_service):
             identifier=target_table,
         )
         self.adapter.swap_table(source_relation, target_relation)
-        assert (
-            self.adapter.get_glue_table_location(target_relation)
-            == f"s3://{BUCKET}/tables/{source_table}"
-        )
+        assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}"
 
     @mock_aws
     def test_swap_table_without_partitions(self, mock_aws_service):
@@ -820,10 +751,7 @@ def test_swap_table_without_partitions(self, mock_aws_service):
             identifier=target_table,
         )
         self.adapter.swap_table(source_relation, target_relation)
-        assert (
-            self.adapter.get_glue_table_location(target_relation)
-            == f"s3://{BUCKET}/tables/{source_table}"
-        )
+        assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}"
 
     @mock_aws
     def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
@@ -853,14 +781,11 @@ def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
         self.adapter.swap_table(source_relation, target_relation)
 
         glue_client = boto3.client("glue", region_name=AWS_REGION)
-        target_table_partitions = glue_client.get_partitions(
-            DatabaseName=DATABASE_NAME, TableName=target_table
-        ).get("Partitions")
-
-        assert (
-            self.adapter.get_glue_table_location(target_relation)
-            == f"s3://{BUCKET}/tables/{source_table}"
+        target_table_partitions = glue_client.get_partitions(DatabaseName=DATABASE_NAME, TableName=target_table).get(
+            "Partitions"
         )
+
+        assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}"
         assert len(target_table_partitions) == 0
 
     @mock_aws
@@ -874,9 +799,9 @@ def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service):
         mock_aws_service.add_partitions_to_table(DATABASE_NAME, source_table)
         mock_aws_service.create_table_without_partitions(target_table)
         glue_client = boto3.client("glue", region_name=AWS_REGION)
-        target_table_partitions = glue_client.get_partitions(
-            DatabaseName=DATABASE_NAME, TableName=target_table
-        ).get("Partitions")
+        target_table_partitions = glue_client.get_partitions(DatabaseName=DATABASE_NAME, TableName=target_table).get(
+            "Partitions"
+        )
         assert len(target_table_partitions) == 0
         source_relation = self.adapter.Relation.create(
             database=DATA_CATALOG_NAME,
@@ -893,10 +818,7 @@ def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service):
             DatabaseName=DATABASE_NAME, TableName=target_table
         ).get("Partitions")
 
-        assert (
-            self.adapter.get_glue_table_location(target_relation)
-            == f"s3://{BUCKET}/tables/{source_table}"
-        )
+        assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}"
         assert len(target_table_partitions_after) == 26
 
     @mock_aws
@@ -910,9 +832,7 @@ def test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_ca
         mock_aws_service.add_table_version(DATABASE_NAME, table_name)
         mock_aws_service.add_table_version(DATABASE_NAME, table_name)
         glue = boto3.client("glue", region_name=AWS_REGION)
-        table_versions = glue.get_table_versions(
-            DatabaseName=DATABASE_NAME, TableName=table_name
-        ).get("TableVersions")
+        table_versions = glue.get_table_versions(DatabaseName=DATABASE_NAME, TableName=table_name).get("TableVersions")
         assert len(table_versions) == 4
         version_to_keep = 1
         relation = self.adapter.Relation.create(
@@ -920,9 +840,7 @@ def test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_ca
             schema=DATABASE_NAME,
             identifier=table_name,
         )
-        versions_to_expire = self.adapter._get_glue_table_versions_to_expire(
-            relation, version_to_keep
-        )
+        versions_to_expire = self.adapter._get_glue_table_versions_to_expire(relation, version_to_keep)
         assert len(versions_to_expire) == 3
         assert [v["VersionId"] for v in versions_to_expire] == ["3", "2", "1"]
 
@@ -937,9 +855,7 @@ def test_expire_glue_table_versions(self, mock_aws_service):
         mock_aws_service.add_table_version(DATABASE_NAME, table_name)
         mock_aws_service.add_table_version(DATABASE_NAME, table_name)
         glue = boto3.client("glue", region_name=AWS_REGION)
-        table_versions = glue.get_table_versions(
-            DatabaseName=DATABASE_NAME, TableName=table_name
-        ).get("TableVersions")
+        table_versions = glue.get_table_versions(DatabaseName=DATABASE_NAME, TableName=table_name).get("TableVersions")
        assert len(table_versions) == 4
         version_to_keep = 1
         relation = self.adapter.Relation.create(
@@ -961,9 +877,7 @@ def test_upload_seed_to_s3(self, mock_aws_service):
         table = "data"
 
         s3_client = boto3.client("s3", region_name=AWS_REGION)
-        s3_client.create_bucket(
-            Bucket=BUCKET, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}
-        )
+        s3_client.create_bucket(Bucket=BUCKET, CreateBucketConfiguration={"LocationConstraint": AWS_REGION})
 
         relation = self.adapter.Relation.create(
             database=DATA_CATALOG_NAME,
@@ -996,9 +910,7 @@ def test_upload_seed_to_s3_external_location(self, mock_aws_service):
         external_location = f"s3://{bucket}/{prefix}"
 
         s3_client = boto3.client("s3", region_name=AWS_REGION)
-        s3_client.create_bucket(
-            Bucket=bucket, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}
-        )
+        s3_client.create_bucket(Bucket=bucket, CreateBucketConfiguration={"LocationConstraint": AWS_REGION})
 
         relation = self.adapter.Relation.create(
             database=DATA_CATALOG_NAME,
@@ -1129,13 +1041,8 @@ def test_persist_docs_to_glue_comment(self, mock_aws_service):
         )
         glue = boto3.client("glue", region_name=AWS_REGION)
         table = glue.get_table(DatabaseName=DATABASE_NAME, Name=table_name).get("Table")
-        assert (
-            table["Description"] == "A table with str, 123, &^% \" and ' and an other paragraph."
-        )
-        assert (
-            table["Parameters"]["comment"]
-            == "A table with str, 123, &^% \" and ' and an other paragraph."
-        )
+        assert table["Description"] == "A table with str, 123, &^% \" and ' and an other paragraph."
+        assert table["Parameters"]["comment"] == "A table with str, 123, &^% \" and ' and an other paragraph."
         col_id = [col for col in table["StorageDescriptor"]["Columns"] if col["Name"] == "id"][0]
         assert col_id["Comment"] == "A column with str, 123, &^% \" and ' and an other paragraph."
         assert col_id["Parameters"] == {"primary_key": "True"}
@@ -1189,9 +1096,7 @@ def test_delete_from_glue_catalog(self, mock_aws_service):
         mock_aws_service.create_database()
         mock_aws_service.create_table("tbl_name")
         self.adapter.acquire_connection("dummy")
-        relation = self.adapter.Relation.create(
-            database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_name"
-        )
+        relation = self.adapter.Relation.create(database=DATA_CATALOG_NAME, schema=DATABASE_NAME, identifier="tbl_name")
         self.adapter.delete_from_glue_catalog(relation)
         glue = boto3.client("glue", region_name=AWS_REGION)
         tables_list = glue.get_tables(DatabaseName=DATABASE_NAME).get("TableList")
@@ -1262,22 +1167,8 @@ def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service):
     @pytest.mark.parametrize(
         "column,expected",
         [
-            pytest.param(
-                {
-                    "Name": "user_id",
-                    "Type": "int",
-                    "Parameters": {"iceberg.field.current": "true"},
-                },
-                True,
-            ),
-            pytest.param(
-                {
-                    "Name": "user_id",
-                    "Type": "int",
-                    "Parameters": {"iceberg.field.current": "false"},
-                },
-                False,
-            ),
+            pytest.param({"Name": "user_id", "Type": "int", "Parameters": {"iceberg.field.current": "true"}}, True),
+            pytest.param({"Name": "user_id", "Type": "int", "Parameters": {"iceberg.field.current": "false"}}, False),
             pytest.param({"Name": "user_id", "Type": "int"}, True),
         ],
     )
@@ -1422,9 +1313,7 @@ def test_convert_datetime_type(self):
             ["", "20190102T01:01:01Z", "2019-01-01 01:01:01"],
             ["", "20190103T01:01:01Z", "2019-01-01 01:01:01"],
         ]
-        agate_table = self._make_table_of(
-            rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime]
-        )
+        agate_table = self._make_table_of(rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime])
         expected = ["timestamp", "timestamp", "timestamp"]
         for col_idx, expect in enumerate(expected):
             assert AthenaAdapter.convert_datetime_type(agate_table, col_idx) == expect