fix: cross account catalog_id glue client function calls #370

Merged · 10 commits · Sep 7, 2023
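In short, this PR threads a resolved Glue `CatalogId` through every Glue client call in `impl.py`, so that relations living in a cross-account data catalog are looked up in the catalog's owning account rather than the caller's. The sketch below only illustrates that general pattern; `get_catalog_id` here is an illustrative stand-in, not the adapter's actual helper, and the `catalog-id` parameter key is an assumption based on how Athena registers cross-account Glue catalogs.

```python
from typing import Any, Dict, Optional

import boto3


def get_catalog_id(data_catalog: Optional[Dict[str, Any]]) -> Optional[str]:
    # Illustrative stand-in: for a Glue-type Athena data catalog, the owning
    # AWS account ID is assumed to live under Parameters["catalog-id"];
    # returning None means "use the caller's own account".
    if data_catalog and data_catalog.get("Type") == "GLUE":
        return data_catalog.get("Parameters", {}).get("catalog-id")
    return None


def fetch_table(glue_client: Any, catalog_id: Optional[str], database: str, table: str) -> Dict[str, Any]:
    # Pass CatalogId only when a cross-account catalog was resolved; omitting
    # it makes Glue default to the calling account's catalog.
    kwargs: Dict[str, Any] = {"DatabaseName": database, "Name": table}
    if catalog_id:
        kwargs["CatalogId"] = catalog_id
    return glue_client.get_table(**kwargs)["Table"]


if __name__ == "__main__":
    # Usage example only; requires AWS credentials and an existing table.
    glue = boto3.client("glue", region_name="eu-west-1")
    print(fetch_table(glue, None, "my_schema", "my_table")["Name"])
```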
57 changes: 47 additions & 10 deletions dbt/adapters/athena/impl.py
@@ -223,11 +223,15 @@ def get_glue_table(self, relation: AthenaRelation) -> Optional[GetTableResponseT
"""
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

try:
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)
except ClientError as e:
if e.response["Error"]["Code"] == "EntityNotFoundException":
LOGGER.debug(f"Table {relation.render()} does not exists - Ignoring")
@@ -596,16 +600,25 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(src_relation.database)
src_catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

src_table = glue_client.get_table(DatabaseName=src_relation.schema, Name=src_relation.identifier).get("Table")
src_table = glue_client.get_table(
CatalogId=src_catalog_id, DatabaseName=src_relation.schema, Name=src_relation.identifier
).get("Table")

src_table_partitions = glue_client.get_partitions(
DatabaseName=src_relation.schema, TableName=src_relation.identifier
CatalogId=src_catalog_id, DatabaseName=src_relation.schema, TableName=src_relation.identifier
).get("Partitions")

data_catalog = self._get_data_catalog(target_relation.database)
target_catalog_id = get_catalog_id(data_catalog)

target_table_partitions = glue_client.get_partitions(
DatabaseName=target_relation.schema, TableName=target_relation.identifier
CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableName=target_relation.identifier
).get("Partitions")

target_table_version = {
@@ -618,7 +631,9 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
}

# perform a table swap
glue_client.update_table(DatabaseName=target_relation.schema, TableInput=target_table_version)
glue_client.update_table(
CatalogId=target_catalog_id, DatabaseName=target_relation.schema, TableInput=target_table_version
)
LOGGER.debug(f"Table {target_relation.render()} swapped with the content of {src_relation.render()}")

# we delete the target table partitions in any case
@@ -627,6 +642,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
if target_table_partitions:
for partition_batch in get_chunks(target_table_partitions, AthenaAdapter.BATCH_DELETE_PARTITION_API_LIMIT):
glue_client.batch_delete_partition(
CatalogId=target_catalog_id,
DatabaseName=target_relation.schema,
TableName=target_relation.identifier,
PartitionsToDelete=[{"Values": partition["Values"]} for partition in partition_batch],
@@ -635,6 +651,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati
if src_table_partitions:
for partition_batch in get_chunks(src_table_partitions, AthenaAdapter.BATCH_CREATE_PARTITION_API_LIMIT):
glue_client.batch_create_partition(
CatalogId=target_catalog_id,
DatabaseName=target_relation.schema,
TableName=target_relation.identifier,
PartitionInputList=[
@@ -676,6 +693,9 @@ def expire_glue_table_versions(
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

@@ -688,7 +708,10 @@
location = v["Table"]["StorageDescriptor"]["Location"]
try:
glue_client.delete_table_version(
DatabaseName=relation.schema, TableName=relation.identifier, VersionId=str(version)
CatalogId=catalog_id,
DatabaseName=relation.schema,
TableName=relation.identifier,
VersionId=str(version),
)
deleted_versions.append(version)
LOGGER.debug(f"Deleted version {version} of table {relation.render()} ")
@@ -720,13 +743,16 @@ def persist_docs_to_glue(
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

# By default, there is no need to update Glue Table
need_udpate_table = False
# Get Table from Glue
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.name)["Table"]
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.name)["Table"]
# Prepare new version of Glue Table picking up significant fields
updated_table = self._get_table_input(table)
# Update table description
@@ -766,7 +792,10 @@ def persist_docs_to_glue(
# It prevents redundant schema version creating after incremental runs.
if need_udpate_table:
glue_client.update_table(
DatabaseName=relation.schema, TableInput=updated_table, SkipArchive=skip_archive_table_version
CatalogId=catalog_id,
DatabaseName=relation.schema,
TableInput=updated_table,
SkipArchive=skip_archive_table_version,
)

@available
@@ -797,11 +826,16 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

try:
table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)["Table"]
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)[
"Table"
]
except ClientError as e:
if e.response["Error"]["Code"] == "EntityNotFoundException":
LOGGER.debug("table not exist, catching the error")
@@ -829,11 +863,14 @@ def delete_from_glue_catalog(self, relation: AthenaRelation) -> None:
conn = self.connections.get_thread_connection()
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

try:
glue_client.delete_table(DatabaseName=schema_name, Name=table_name)
glue_client.delete_table(CatalogId=catalog_id, DatabaseName=schema_name, Name=table_name)
LOGGER.debug(f"Deleted table from glue catalog: {relation.render()}")
except ClientError as e:
if e.response["Error"]["Code"] == "EntityNotFoundException":
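On the test side, every test that now exercises catalog-ID resolution gains a `@mock_sts` decorator alongside the existing moto mocks. A plausible reading (an assumption, not stated in this diff) is that resolving the default catalog's ID involves an STS `GetCallerIdentity` call to discover the current account, so the STS API must be mocked too. A minimal sketch of that test pattern, using moto's fixed default account ID and an illustrative `resolve_default_catalog_id` helper:

```python
import boto3
from moto import mock_sts


def resolve_default_catalog_id() -> str:
    # Illustrative helper: fall back to the caller's own AWS account ID when
    # no cross-account Glue catalog is configured.
    return boto3.client("sts", region_name="eu-west-1").get_caller_identity()["Account"]


@mock_sts
def test_resolve_default_catalog_id_uses_caller_account():
    # moto's STS mock answers GetCallerIdentity with its default account ID.
    assert resolve_default_catalog_id() == "123456789012"
```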
22 changes: 22 additions & 0 deletions tests/unit/test_adapter.py
@@ -401,6 +401,7 @@ def test_generate_s3_location(
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location(self, dbt_debug_caplog, mock_aws_service):
table_name = "test_table"
self.adapter.acquire_connection("dummy")
@@ -417,6 +418,7 @@ def test_get_table_location(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog, mock_aws_service):
table_name = "test_table"
self.adapter.acquire_connection("dummy")
@@ -438,6 +440,7 @@ def test_get_table_location_raise_s3_location_exception(self, dbt_debug_caplog,
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service):
view_name = "view"
self.adapter.acquire_connection("dummy")
@@ -452,6 +455,7 @@ def test_get_table_location_for_view(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_get_table_location_with_failure(self, dbt_debug_caplog, mock_aws_service):
table_name = "test_table"
self.adapter.acquire_connection("dummy")
@@ -500,6 +504,7 @@ def test_clean_up_partitions_will_work(self, dbt_debug_caplog, mock_aws_service)

@mock_glue
@mock_athena
@mock_sts
def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -517,6 +522,7 @@ def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, mock_aws_se

@mock_glue
@mock_athena
@mock_sts
def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -534,6 +540,7 @@ def test_clean_up_table_view(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test_clean_up_table_delete_table(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -844,6 +851,7 @@ def test_parse_s3_path(self, s3_path, expected):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_with_partitions(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -870,6 +878,7 @@ def test_swap_table_with_partitions(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_without_partitions(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -894,6 +903,7 @@ def test_swap_table_without_partitions(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -931,6 +941,7 @@ def test_swap_table_with_partitions_to_one_without(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -990,6 +1001,7 @@ def test__get_glue_table_versions_to_expire(self, mock_aws_service, dbt_debug_ca
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_expire_glue_table_versions(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1101,6 +1113,7 @@ def test_get_work_group_output_location_not_enforced(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_persist_docs_to_glue_no_comment(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1142,6 +1155,7 @@ def test_persist_docs_to_glue_no_comment(self, mock_aws_service):
@mock_athena
@mock_glue
@mock_s3
@mock_sts
def test_persist_docs_to_glue_comment(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1194,6 +1208,7 @@ def test_list_schemas(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_get_columns_in_relation(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1214,6 +1229,7 @@ def test_get_columns_in_relation(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_get_columns_in_relation_not_found_table(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1229,6 +1245,7 @@ def test_get_columns_in_relation_not_found_table(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_delete_from_glue_catalog(self, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1242,6 +1259,7 @@ def test_delete_from_glue_catalog(self, mock_aws_service):

@mock_athena
@mock_glue
@mock_sts
def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1258,6 +1276,7 @@ def test_delete_from_glue_catalog_not_found_table(self, dbt_debug_caplog, mock_a
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1272,6 +1291,7 @@ def test__get_relation_type_table(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1286,6 +1306,7 @@ def test__get_relation_type_with_no_type(self, dbt_debug_caplog, mock_aws_servic
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
@@ -1300,6 +1321,7 @@ def test__get_relation_type_view(self, dbt_debug_caplog, mock_aws_service):
@mock_glue
@mock_s3
@mock_athena
@mock_sts
def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service):
mock_aws_service.create_data_catalog()
mock_aws_service.create_database()
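For completeness, the partition calls in `swap_table` follow the same idea: batch the partition values (Glue's batch APIs cap each request, 25 for `BatchDeletePartition`) and scope every request with `CatalogId`. A rough sketch of that loop, with `get_chunks` written inline here as an illustrative chunker rather than the adapter's helper:

```python
from typing import Any, Dict, Iterator, List

BATCH_DELETE_PARTITION_API_LIMIT = 25  # Glue's documented per-request cap for BatchDeletePartition


def get_chunks(items: List[Any], size: int) -> Iterator[List[Any]]:
    # Illustrative chunker: yield consecutive slices of at most `size` items.
    for start in range(0, len(items), size):
        yield items[start : start + size]


def delete_partitions(
    glue_client: Any, catalog_id: str, database: str, table: str, partitions: List[Dict[str, Any]]
) -> None:
    # Delete partitions in batches, always scoping the request to the owning
    # catalog via CatalogId so cross-account tables are cleaned up correctly.
    for batch in get_chunks(partitions, BATCH_DELETE_PARTITION_API_LIMIT):
        glue_client.batch_delete_partition(
            CatalogId=catalog_id,
            DatabaseName=database,
            TableName=table,
            PartitionsToDelete=[{"Values": partition["Values"]} for partition in batch],
        )
```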