Merge pull request #268 from moomindani/refine_tests
[test] Refine unit test and integ tests
menuetb authored Nov 10, 2023
2 parents bf23d5e + c6ed306 commit d095ec8
Showing 6 changed files with 185 additions and 45 deletions.
2 changes: 2 additions & 0 deletions dev-requirements.txt
@@ -24,6 +24,8 @@ black==23.11.0
# Adapter specific dependencies
waiter
boto3
moto~=4.2.7
pyparsing

dbt-core~=1.7.1
dbt-spark~=1.7.1
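Note: moto provides in-process mocks of AWS APIs, so the new unit tests can exercise Glue catalog calls without credentials or network access. A minimal sketch (hypothetical example, not part of this change) of the decorator this new dependency enables:

import boto3
from moto import mock_glue

@mock_glue
def test_glue_calls_stay_local():
    # every boto3 Glue call inside the decorated function hits moto's in-memory backend
    glue = boto3.client("glue", region_name="us-east-1")
    glue.create_database(DatabaseInput={"Name": "example_db"})
    assert glue.get_database(Name="example_db")["Database"]["Name"] == "example_db"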
49 changes: 7 additions & 42 deletions tests/functional/adapter/test_basic.py
@@ -1,8 +1,6 @@
import pytest

import boto3
import os
from urllib.parse import urlparse
from dbt.tests.adapter.basic.test_base import BaseSimpleMaterializations
from dbt.tests.adapter.basic.test_singular_tests import BaseSingularTests
from dbt.tests.adapter.basic.test_singular_tests_ephemeral import BaseSingularTestsEphemeral
@@ -29,7 +27,7 @@
check_relations_equal,
)

from tests.util import get_s3_location, get_region
from tests.util import get_s3_location, get_region, cleanup_s3_location


s3bucket = get_s3_location()
@@ -61,11 +59,6 @@
base_materialized_var_sql = config_materialized_var + config_incremental_strategy + model_base


def cleanup_s3_location():
client = boto3.client("s3", region_name=region)
S3Url(s3bucket + schema_name).delete_all_keys_v2(client)


class TestSimpleMaterializationsGlue(BaseSimpleMaterializations):
# all tests within this class have the same schema
@pytest.fixture(scope="class")
@@ -92,7 +85,7 @@ def models(self):

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

pass
@@ -131,7 +124,7 @@ def models(self):

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

# test_ephemeral with refresh table
@@ -184,7 +177,7 @@ def unique_schema(request, prefix) -> str:
class TestIncrementalGlue(BaseIncremental):
@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

@pytest.fixture(scope="class")
@@ -250,16 +243,18 @@ def unique_schema(request, prefix) -> str:

@pytest.fixture(scope='class', autouse=True)
def cleanup(self):
cleanup_s3_location()
cleanup_s3_location(s3bucket + schema_name, region)
yield

def test_generic_tests(self, project):
# seed command
results = run_dbt(["seed"])

relation = relation_from_name(project.adapter, "base")
relation_table_model = relation_from_name(project.adapter, "table_model")
# run refresh table to invalidate the previous parquet file paths
project.run_sql(f"refresh table {relation}")
project.run_sql(f"refresh table {relation_table_model}")

# test command selecting base model
results = run_dbt(["test", "-m", "base"])
@@ -291,33 +286,3 @@ def test_generic_tests(self, project):

#class TestSnapshotTimestampGlue(BaseSnapshotTimestamp):
# pass

class S3Url(object):
def __init__(self, url):
self._parsed = urlparse(url, allow_fragments=False)

@property
def bucket(self):
return self._parsed.netloc

@property
def key(self):
if self._parsed.query:
return self._parsed.path.lstrip("/") + "?" + self._parsed.query
else:
return self._parsed.path.lstrip("/")

@property
def url(self):
return self._parsed.geturl()

def delete_all_keys_v2(self, client):
bucket = self.bucket
prefix = self.key

for response in client.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix):
if 'Contents' not in response:
continue
for content in response['Contents']:
print("Deleting: s3://" + bucket + "/" + content['Key'])
client.delete_object(Bucket=bucket, Key=content['Key'])
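Note: the cleanup fixture now passes an explicit S3 path and region to the shared helper instead of relying on module-level state inside this file. A condensed sketch of the pattern repeated in each test class above (schema_name is a placeholder here; the real value is defined elsewhere in this module):

import pytest
from tests.util import get_s3_location, get_region, cleanup_s3_location

s3bucket = get_s3_location()
region = get_region()
schema_name = "dbt_functional_test_01"  # assumed placeholder value

@pytest.fixture(scope="class", autouse=True)
def cleanup():
    # empty this class's S3 prefix once before its tests run; nothing to tear down afterwards
    cleanup_s3_location(s3bucket + schema_name, region)
    yield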
4 changes: 4 additions & 0 deletions tests/unit/constants.py
@@ -0,0 +1,4 @@
CATALOG_ID = "1234567890101"
DATABASE_NAME = "test_dbt_glue"
BUCKET_NAME = "test-dbt-glue"
AWS_REGION = "us-east-1"
42 changes: 39 additions & 3 deletions tests/unit/test_adapter.py
@@ -1,12 +1,15 @@
from typing import Any, Dict, Optional
import unittest
from unittest import mock
from moto import mock_glue

from dbt.config import RuntimeConfig

import dbt.flags as flags
from dbt.adapters.glue import GlueAdapter
from dbt.adapters.glue.relation import SparkRelation
from tests.util import config_from_parts_or_dicts
from .util import MockAWSService


class TestGlueAdapter(unittest.TestCase):
@@ -33,8 +36,8 @@ def setUp(self):
"region": "us-east-1",
"workers": 2,
"worker_type": "G.1X",
"schema": "dbt_functional_test_01",
"database": "dbt_functional_test_01",
"schema": "dbt_unit_test_01",
"database": "dbt_unit_test_01",
}
},
"target": "test",
@@ -56,5 +59,38 @@ def test_glue_connection(self):

self.assertEqual(connection.state, "open")
self.assertEqual(connection.type, "glue")
self.assertEqual(connection.credentials.schema, "dbt_functional_test_01")
self.assertEqual(connection.credentials.schema, "dbt_unit_test_01")
self.assertIsNotNone(connection.handle)


@mock_glue
def test_get_table_type(self):
config = self._get_config()
adapter = GlueAdapter(config)

database_name = "dbt_unit_test_01"
table_name = "test_table"
mock_aws_service = MockAWSService()
mock_aws_service.create_database(name=database_name)
mock_aws_service.create_iceberg_table(table_name=table_name, database_name=database_name)
target_relation = SparkRelation.create(
schema=database_name,
identifier=table_name,
)
with mock.patch("dbt.adapters.glue.connections.open"):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
self.assertEqual(adapter.get_table_type(target_relation), "iceberg_table")

@mock_glue
def test_hudi_merge_table(self):
config = self._get_config()
adapter = GlueAdapter(config)
target_relation = SparkRelation.create(
schema="dbt_unit_test_01",
name="test_hudi_merge_table",
)
with mock.patch("dbt.adapters.glue.connections.open"):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
adapter.hudi_merge_table(target_relation, "SELECT 1", "id", "category", "empty", None, None)
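Note: both new tests combine two layers of mocking: @mock_glue keeps the boto3 Glue catalog calls local, while patching dbt.adapters.glue.connections.open prevents acquire_connection from starting a real Glue interactive session. A condensed sketch of that combination (config is whatever RuntimeConfig the caller builds, e.g. via the class's _get_config helper above):

from unittest import mock
from moto import mock_glue
from dbt.adapters.glue import GlueAdapter

@mock_glue
def build_adapter_with_mocked_connection(config):
    adapter = GlueAdapter(config)
    with mock.patch("dbt.adapters.glue.connections.open"):
        connection = adapter.acquire_connection("dummy")
        connection.handle  # lazy-loads the handle without reaching a real Glue session
    return adapter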
96 changes: 96 additions & 0 deletions tests/unit/util.py
@@ -0,0 +1,96 @@
from typing import Optional
import boto3

from .constants import AWS_REGION, BUCKET_NAME, CATALOG_ID, DATABASE_NAME


class MockAWSService:
def create_database(self, name: str = DATABASE_NAME, catalog_id: str = CATALOG_ID):
glue = boto3.client("glue", region_name=AWS_REGION)
glue.create_database(DatabaseInput={"Name": name}, CatalogId=catalog_id)

def create_table(
self,
table_name: str,
database_name: str = DATABASE_NAME,
catalog_id: str = CATALOG_ID,
location: Optional[str] = "auto",
):
glue = boto3.client("glue", region_name=AWS_REGION)
if location == "auto":
location = f"s3://{BUCKET_NAME}/tables/{table_name}"
glue.create_table(
CatalogId=catalog_id,
DatabaseName=database_name,
TableInput={
"Name": table_name,
"StorageDescriptor": {
"Columns": [
{
"Name": "id",
"Type": "string",
},
{
"Name": "country",
"Type": "string",
},
],
"Location": location,
},
"PartitionKeys": [
{
"Name": "dt",
"Type": "date",
},
],
"TableType": "table",
"Parameters": {
"compressionType": "snappy",
"classification": "parquet",
"projection.enabled": "false",
"typeOfData": "file",
},
},
)

def create_iceberg_table(
self,
table_name: str,
database_name: str = DATABASE_NAME,
catalog_id: str = CATALOG_ID):
glue = boto3.client("glue", region_name=AWS_REGION)
glue.create_table(
CatalogId=catalog_id,
DatabaseName=database_name,
TableInput={
"Name": table_name,
"StorageDescriptor": {
"Columns": [
{
"Name": "id",
"Type": "string",
},
{
"Name": "country",
"Type": "string",
},
{
"Name": "dt",
"Type": "date",
},
],
"Location": f"s3://{BUCKET_NAME}/tables/data/{table_name}",
},
"PartitionKeys": [
{
"Name": "dt",
"Type": "date",
},
],
"TableType": "EXTERNAL_TABLE",
"Parameters": {
"metadata_location": f"s3://{BUCKET_NAME}/tables/metadata/{table_name}/123.json",
"table_type": "iceberg",
},
},
)
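Note: MockAWSService only talks to moto's in-memory Glue backend, so it must run inside a mock_glue context. A hedged usage sketch (hypothetical test, not part of this change) confirming that create_iceberg_table registers the Iceberg markers in the mocked catalog:

import boto3
from moto import mock_glue

from .constants import AWS_REGION, DATABASE_NAME
from .util import MockAWSService

@mock_glue
def test_create_iceberg_table_registers_metadata():
    mock_aws_service = MockAWSService()
    mock_aws_service.create_database()
    mock_aws_service.create_iceberg_table(table_name="example")
    glue = boto3.client("glue", region_name=AWS_REGION)
    table = glue.get_table(DatabaseName=DATABASE_NAME, Name="example")["Table"]
    # the table_type parameter is what marks the table as Iceberg in the catalog
    assert table["Parameters"]["table_type"] == "iceberg"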
37 changes: 37 additions & 0 deletions tests/util.py
@@ -1,4 +1,6 @@
import os
import boto3
from urllib.parse import urlparse
from dbt.config.project import PartialProject


@@ -110,3 +112,38 @@ def get_s3_location():
def get_role_arn():
return os.environ.get("DBT_GLUE_ROLE_ARN", f"arn:aws:iam::{get_account_id()}:role/GlueInteractiveSessionRole")


def cleanup_s3_location(path, region):
client = boto3.client("s3", region_name=region)
S3Url(path).delete_all_keys_v2(client)


class S3Url(object):
def __init__(self, url):
self._parsed = urlparse(url, allow_fragments=False)

@property
def bucket(self):
return self._parsed.netloc

@property
def key(self):
if self._parsed.query:
return self._parsed.path.lstrip("/") + "?" + self._parsed.query
else:
return self._parsed.path.lstrip("/")

@property
def url(self):
return self._parsed.geturl()

def delete_all_keys_v2(self, client):
bucket = self.bucket
prefix = self.key

for response in client.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix):
if 'Contents' not in response:
continue
for content in response['Contents']:
print("Deleting: s3://" + bucket + "/" + content['Key'])
client.delete_object(Bucket=bucket, Key=content['Key'])
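Note: the relocated helper can also be exercised locally. A minimal sketch (hypothetical test using moto's mock_s3, which this change itself does not use) verifying that delete_all_keys_v2 only removes keys under the given prefix:

import boto3
from moto import mock_s3

from tests.util import cleanup_s3_location

@mock_s3
def test_cleanup_s3_location_removes_only_the_prefix():
    region = "us-east-1"
    s3 = boto3.client("s3", region_name=region)
    s3.create_bucket(Bucket="example-bucket")
    s3.put_object(Bucket="example-bucket", Key="schema_a/part-000.parquet", Body=b"x")
    s3.put_object(Bucket="example-bucket", Key="schema_b/part-000.parquet", Body=b"x")

    cleanup_s3_location("s3://example-bucket/schema_a", region)

    remaining = [o["Key"] for o in s3.list_objects_v2(Bucket="example-bucket").get("Contents", [])]
    assert remaining == ["schema_b/part-000.parquet"]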
