diff --git a/.gitignore b/.gitignore
index 04b7133..2eb80ab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,7 @@ MANIFEST
 *.json
 !sample_sheet1.json
 !sample_sheet2.json
+!sample_sheet3.json
 
 # Unit test / coverage reports
 htmlcov/
diff --git a/dbt_automation/operations/explodejsonarray.py b/dbt_automation/operations/explodejsonarray.py
new file mode 100644
index 0000000..9196ade
--- /dev/null
+++ b/dbt_automation/operations/explodejsonarray.py
@@ -0,0 +1,102 @@
+"""explode elements out of a json list into their own rows"""
+from logging import basicConfig, getLogger, INFO
+
+from dbt_automation.utils.dbtproject import dbtProject
+from dbt_automation.utils.columnutils import quote_columnname
+from dbt_automation.utils.columnutils import make_cleaned_column_names, dedup_list
+
+basicConfig(level=INFO)
+logger = getLogger()
+
+
+# pylint:disable=unused-argument,logging-fstring-interpolation
+def explodejsonarray(config: dict, warehouse, project_dir: str):
+    """
+    source_schema: name of the input schema
+    input_name: name of the input model
+    dest_schema: name of the output schema
+    output_model: name of the output model
+    columns_to_copy: list of columns to copy from the input model
+    json_column: name of the json column to flatten
+    """
+
+    source_schema = config["source_schema"]
+    source_table = config["source_table"]
+    input_model = config.get("input_model")
+    input_source = config.get("input_source")
+    if input_model is None and input_source is None:
+        raise ValueError("either input_model or input_source must be specified")
+    dest_schema = config["dest_schema"]
+    output_model = config["output_model"]
+    columns_to_copy = config["columns_to_copy"]
+    json_column = config["json_column"]
+
+    model_code = f'{{{{ config(materialized="table", schema="{dest_schema}") }}}}'
+    model_code += "\n"
+
+    if columns_to_copy is None:
+        model_code += "SELECT "
+    elif columns_to_copy == "*":
+        model_code += "SELECT *, "
+    else:
+        select_list = [quote_columnname(col, warehouse.name) for col in columns_to_copy]
+        model_code += f"SELECT {', '.join(select_list)}, "
+
+    model_code += "\n"
+
+    json_columns = warehouse.get_json_columnspec_from_array(
+        source_schema, source_table, json_column
+    )
+
+    # convert to sql-friendly column names
+    sql_columns = make_cleaned_column_names(json_columns)
+
+    # after cleaning we may have duplicates
+    sql_columns = dedup_list(sql_columns)
+
+    if warehouse.name == "postgres":
+        select_list = []
+        for json_field, sql_column in zip(json_columns, sql_columns):
+            select_list.append(
+                warehouse.json_extract_from_array_op(
+                    json_column, json_field, sql_column
+                )
+            )
+        model_code += ",".join(select_list)
+        model_code += "\nFROM\n"
+
+        if input_model:
+            model_code += "{{ref('" + input_model + "')}}"
+        else:
+            model_code += "{{source('" + source_schema + "', '" + input_source + "')}}"
+
+        model_code += "\n"
+
+    elif warehouse.name == "bigquery":
+        select_list = []
+        for json_field, sql_column in zip(json_columns, sql_columns):
+            select_list.append(
+                warehouse.json_extract_op("JVAL", json_field, sql_column)
+            )
+        model_code += ",".join(select_list)
+        model_code += "\nFROM\n"
+
+        if input_model:
+            model_code += "{{ref('" + input_model + "')}}"
+        else:
+            model_code += "{{source('" + source_schema + "', '" + input_source + "')}}"
+
+        model_code += f""" CROSS JOIN UNNEST((
+            SELECT JSON_EXTRACT_ARRAY(`{json_column}`, '$')
+            FROM """
+
+        if input_model:
+            model_code += "{{ref('" + input_model + "')}}"
+        else:
+            model_code += "{{source('" + source_schema + "', '" + input_source + "')}}"
+
+        model_code += ")) `JVAL`"
+
+    dbtproject = dbtProject(project_dir)
+    dbtproject.ensure_models_dir(dest_schema)
+    dbtproject.write_model(dest_schema, output_model, model_code)
diff --git a/dbt_automation/utils/bigquery.py b/dbt_automation/utils/bigquery.py
index de10c5d..00ead65 100644
--- a/dbt_automation/utils/bigquery.py
+++ b/dbt_automation/utils/bigquery.py
@@ -1,11 +1,11 @@
 """utilities for working with bigquery"""
-from logging import basicConfig, getLogger, INFO
 import os
+import json
+from logging import basicConfig, getLogger, INFO
 
 from google.cloud import bigquery
 from google.cloud.exceptions import NotFound
 from google.oauth2 import service_account
-import json
 
 basicConfig(level=INFO)
 logger = getLogger()
@@ -18,8 +18,10 @@ def __init__(self, conn_info=None, location=None):
         self.name = "bigquery"
         self.bqclient = None
         if conn_info is None:  # take creds from env
-            creds_file = open(os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
-            conn_info = json.load(creds_file)
+            with open(
+                os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), "r", encoding="utf-8"
+            ) as creds_file:
+                conn_info = json.load(creds_file)
             location = os.getenv("BIQUERY_LOCATION")
 
         creds1 = service_account.Credentials.from_service_account_info(conn_info)
@@ -69,10 +71,10 @@ def get_json_columnspec(
         query = self.execute(
             f'''
                 CREATE TEMP FUNCTION jsonObjectKeys(input STRING)
-            RETURNS Array<String>
-            LANGUAGE js AS """
-            return Object.keys(JSON.parse(input));
-            """;
+                RETURNS Array<String>
+                LANGUAGE js AS """
+                return Object.keys(JSON.parse(input));
+                """;
 
                 WITH keys AS (
                 SELECT jsonObjectKeys({column}) AS keys
@@ -89,6 +91,39 @@ def get_json_columnspec(
         )
         return [json_field["k"] for json_field in query]
 
+    def get_json_columnspec_from_array(self, schema: str, table: str, column: str):
+        """get the column schema from the elements of the specified json array for this table"""
+        query = self.execute(
+            f'''
+                CREATE TEMP FUNCTION jsonObjectKeys(input STRING)
+                RETURNS Array<String>
+                LANGUAGE js AS """
+                return Object.keys(JSON.parse(input));
+                """;
+
+                WITH keys as (
+                    WITH key_rows AS (
+                        WITH `json_data` as (
+                            SELECT JSON_EXTRACT_ARRAY(`{column}`, '$')
+                            FROM `{schema}`.`{table}`
+                        )
+                        SELECT * FROM UNNEST(
+                            (SELECT * FROM `json_data`)
+                        ) as key
+                    )
+                    SELECT jsonObjectKeys(`key`)
+                    AS key
+                    FROM key_rows
+                )
+                SELECT DISTINCT k
+                FROM keys
+                CROSS JOIN UNNEST(keys.key)
+                AS k
+            ''',
+            location=self.location,
+        )
+        return [json_field["k"] for json_field in query]
+
     def schema_exists_(self, schema: str) -> bool:
         """checks if the schema exists"""
         try:
@@ -140,7 +175,7 @@ def insert_row(self, schema: str, table: str, row: dict):
     def json_extract_op(self, json_column: str, json_field: str, sql_column: str):
         """outputs a sql query snippet for extracting a json field"""
         json_field = json_field.replace("'", "\\'")
-        return f"json_value({json_column}, '$.\"{json_field}\"') as `{sql_column}`"
+        return f"json_value(`{json_column}`, '$.\"{json_field}\"') as `{sql_column}`"
 
     def close(self):
         """closing the connection and releasing system resources"""
@@ -152,25 +187,25 @@ def close(self):
         return True
 
     def generate_profiles_yaml_dbt(self, project_name, default_schema):
-        """Generates the profiles.yml dictionary object for dbt"""
-        if project_name is None or default_schema is None:
-            raise ValueError("project_name and default_schema are required")
-
-        target = "prod"
         """
-        <project_name>:
+        Generates the profiles.yml dictionary object for dbt
+        <project_name>:
         outputs:
-          prod:
-            keyfile_json:
-            location:
+            prod:
+                keyfile_json:
+                location:
                 method: service-account-json
-            project:
-            schema:
+                project:
+                schema:
                 threads: 4
                 type: bigquery
         target: prod
         """
+        if project_name is None or default_schema is None:
+            raise ValueError("project_name and default_schema are required")
+
+        target = "prod"
+
         profiles_yml = {
             f"{project_name}": {
                 "outputs": {
@@ -179,7 +214,6 @@ def generate_profiles_yaml_dbt(self, project_name, default_schema):
                     "location": self.location,
                     "method": "service-account-json",
                     "project": self.conn_info["project_id"],
-                    "method": "service-account-json",
                     "schema": default_schema,
                     "threads": 4,
                     "type": "bigquery",
diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py
index 8a9213a..043377d 100644
--- a/dbt_automation/utils/postgres.py
+++ b/dbt_automation/utils/postgres.py
@@ -1,7 +1,7 @@
 """helpers for postgres"""
+import os
 from logging import basicConfig, getLogger, INFO
 import psycopg2
-import os
 
 basicConfig(level=INFO)
 logger = getLogger()
@@ -71,7 +71,7 @@ def get_tables(self, schema: str) -> list:
     def get_schemas(self) -> list:
         """returns the list of schema names in the given database connection"""
         resultset = self.execute(
-            f"""
+            """
             SELECT nspname
             FROM pg_namespace
             WHERE nspname NOT LIKE 'pg_%' AND nspname != 'information_schema';
@@ -130,6 +130,18 @@ def get_json_columnspec(self, schema: str, table: str, column: str):
             )
         ]
 
+    def get_json_columnspec_from_array(self, schema: str, table: str, column: str):
+        """get the column schema from the elements of the specified json array for this table"""
+        return [
+            x[0]
+            for x in self.execute(
+                f"""SELECT DISTINCT
+                    jsonb_object_keys(jsonb_array_elements({column}::jsonb)) AS key
+                    FROM "{schema}"."{table}"
+                """
+            )
+        ]
+
     def ensure_schema(self, schema: str):
         """creates the schema if it doesn't exist"""
         self.runcmd(f"CREATE SCHEMA IF NOT EXISTS {schema};")
@@ -164,35 +176,42 @@ def json_extract_op(self, json_column: str, json_field: str, sql_column: str):
         """outputs a sql query snippet for extracting a json field"""
         return f"{json_column}::json->>'{json_field}' as \"{sql_column}\""
 
+    def json_extract_from_array_op(
+        self, json_column: str, json_field: str, sql_column: str
+    ):
+        """outputs a sql query snippet for extracting a json field from elements of a list into several rows"""
+        return f"jsonb_array_elements({json_column}::jsonb)->>'{json_field}' as \"{sql_column}\""
+
     def close(self):
+        """closes the connection"""
         try:
             self.connection.close()
-        except Exception:
+        except Exception:  # pylint:disable=broad-except
             logger.error("something went wrong while closing the postgres connection")
 
         return True
 
     def generate_profiles_yaml_dbt(self, project_name, default_schema):
-        """Generates the profiles.yml dictionary object for dbt"""
-        if project_name is None or default_schema is None:
-            raise ValueError("project_name and default_schema are required")
-
-        target = "prod"
         """
-        <project_name>:
+        Generates the profiles.yml dictionary object for dbt
+        <project_name>:
         outputs:
-          prod:
-            dbname:
-            host:
-            password:
+            prod:
+                dbname:
+                host:
+                password:
                 port: 5432
                 user: airbyte_user
-            schema:
+                schema:
                 threads: 4
                 type: postgres
         target: prod
         """
+        if project_name is None or default_schema is None:
+            raise ValueError("project_name and default_schema are required")
+
+        target = "prod"
+
         profiles_yml = {
             f"{project_name}": {
                 "outputs": {
diff --git a/scripts/main.py b/scripts/main.py
index 2dddedf..5757271 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -20,6 +20,7 @@
 from dbt_automation.operations.flattenjson import flattenjson
 from dbt_automation.operations.regexextraction import regex_extraction
 from dbt_automation.operations.scaffold import scaffold
+from dbt_automation.operations.explodejsonarray import explodejsonarray
 
 OPERATIONS_DICT = {
     "flatten": flatten_operation,
@@ -34,6 +35,7 @@
     "renamecolumns": rename_columns,
     "regexextraction": regex_extraction,
     "scaffold": scaffold,
+    "explodejsonarray": explodejsonarray,
 }
 
 load_dotenv("./../dbconnection.env")
diff --git a/seeds/sample_sheet3.json b/seeds/sample_sheet3.json
new file mode 100644
index 0000000..a6d2975
--- /dev/null
+++ b/seeds/sample_sheet3.json
@@ -0,0 +1,7 @@
+[
+    {
+        "_airbyte_ab_id": "006b18b2-cccd-47f1-a9dc-5638d2d1abc7",
+        "_airbyte_data": "{\"data\": [{\"NGO\": \"IMAGE\", \"SPOC\": \"SPOC C\"}, {\"NGO\": \"FDSR\", \"SPOC\": \"SPOC B\"}, {\"NGO\": \"CRC\", \"SPOC\": \"SPOC A\"}]}",
+        "_airbyte_emitted_at": "2023-09-28 08:50:17+00"
+    }
+]
\ No newline at end of file
diff --git a/seeds/seed.py b/seeds/seed.py
index 0559e3c..bf3c604 100644
--- a/seeds/seed.py
+++ b/seeds/seed.py
@@ -1,12 +1,10 @@
 """This script seeds airbyte's raw data into test warehouse"""
-
-import argparse, os
+import os
+import argparse
+import json
 from logging import basicConfig, getLogger, INFO
 from dbt_automation.utils.warehouseclient import get_client
 from dotenv import load_dotenv
-import csv
-from pathlib import Path
-import json
 
 from google.cloud import bigquery
@@ -21,29 +19,29 @@
 load_dotenv("dbconnection.env")
 
-tablename = "_airbyte_raw_Sheet1"
-json_file = "seeds/sample_sheet1.json"
-
 for json_file, tablename in zip(
-    ["seeds/sample_sheet1.json", "seeds/sample_sheet2.json"],
-    ["_airbyte_raw_Sheet1", "_airbyte_raw_Sheet2"],
+    [
+        "seeds/sample_sheet1.json",
+        "seeds/sample_sheet2.json",
+        "seeds/sample_sheet3.json",
+    ],
+    ["_airbyte_raw_Sheet1", "_airbyte_raw_Sheet2", "_airbyte_raw_Sheet3"],
 ):
     logger.info("seeding %s into %s", json_file, tablename)
 
     data = []
-    with open(json_file, "r") as file:
+    with open(json_file, "r", encoding="utf-8") as file:
         data = json.load(file)
 
     columns = ["_airbyte_ab_id", "_airbyte_data", "_airbyte_emitted_at"]
 
     # schema check; expecting only airbyte raw data
     for row in data:
-        schema_check = [True if key in columns else False for key in row.keys()]
+        schema_check = [key in columns for key in row.keys()]
         if all(schema_check) is False:
-            raise Exception("Schema mismatch")
+            raise Exception("Schema mismatch")  # pylint:disable=broad-exception-raised
 
     if args.warehouse == "postgres":
-        logger.info("Found postgres warehouse")
         conn_info = {
             "host": os.getenv("TEST_PG_DBHOST"),
             "port": os.getenv("TEST_PG_DBPORT"),
@@ -52,6 +50,12 @@
             "password": os.getenv("TEST_PG_DBPASSWORD"),
         }
         schema = os.getenv("TEST_PG_DBSCHEMA_SRC")
+        logger.info(
+            "Found postgres warehouse %s/%s/%s",
+            conn_info["host"],
+            conn_info["database"],
+            schema,
+        )
 
         wc_client = get_client(args.warehouse, conn_info)
@@ -74,22 +78,25 @@
         wc_client.runcmd(create_table_query)
         wc_client.runcmd(truncate_table_query)
 
-        """
-        INSERT INTO your_table_name (column1, column2, column3, ...)
-        VALUES ({}, {}, {}, ...);
-        """
         # seed sample json data into the newly table created
+        # INSERT INTO your_table_name (column1, column2, column3, ...)
+        # VALUES ({}, {}, {}, ...);
         logger.info("seeding sample json data")
         for row in data:
             # Execute the insert query with the data from the CSV
-            insert_query = f"""INSERT INTO {schema}."{tablename}" ({', '.join(columns)}) VALUES ('{row['_airbyte_ab_id']}', JSON '{row['_airbyte_data']}', '{row['_airbyte_emitted_at']}')"""
+            insert_query = f"""
+                INSERT INTO {schema}."{tablename}"
+                ({', '.join(columns)})
+                VALUES
+                ('{row['_airbyte_ab_id']}', JSON '{row['_airbyte_data']}', '{row['_airbyte_emitted_at']}')
+            """
             wc_client.runcmd(insert_query)
 
-    if args.warehouse == "bigquery":
-        logger.info("Found bigquery warehouse")
+    elif args.warehouse == "bigquery":
         conn_info = json.loads(os.getenv("TEST_BG_SERVICEJSON"))
         location = os.getenv("TEST_BG_LOCATION")
         test_dataset = os.getenv("TEST_BG_DATASET_SRC")
+        logger.info("Found bigquery warehouse %s", test_dataset)
 
         wc_client = get_client(args.warehouse, conn_info)
@@ -99,7 +106,7 @@
         logger.info("creating the dataset")
 
         dataset = wc_client.bqclient.create_dataset(dataset, timeout=30, exists_ok=True)
-        logger.info("created dataset : {}".format(dataset.dataset_id))
+        logger.info("created dataset : %s", dataset.dataset_id)
 
         # create the staging table if its does not exist
         table_schema = [
@@ -123,10 +130,10 @@
 
         # seed data
         insert_query = f"""
-            INSERT INTO `{conn_info['project_id']}.{dataset.dataset_id}.{tablename}` (_airbyte_ab_id, _airbyte_data, _airbyte_emitted_at)
-            VALUES
+        INSERT INTO `{conn_info['project_id']}.{dataset.dataset_id}.{tablename}` (_airbyte_ab_id, _airbyte_data, _airbyte_emitted_at)
+        VALUES
         """
-        insert_query_values = ",".join(
+        insert_query_values = ",".join(  # pylint:disable=invalid-name
             [
                 f"""('{row["_airbyte_ab_id"]}', '{row["_airbyte_data"]}', '{row["_airbyte_emitted_at"]}')"""
                 for row in data
diff --git a/seeds/seed_001.yml b/seeds/seed_001.yml
deleted file mode 100644
index fd1bfce..0000000
--- a/seeds/seed_001.yml
+++ /dev/null
@@ -1,9 +0,0 @@
-version: 1
-description: "Yaml template to get you started on automating your dbt work. DO NOT EDIT this, make a copy and use"
-warehouse: bigquery
-
-seed_data:
-  - schema: tests_001
-    tables:
-      - name: model_001
-        csv: seed_001.csv
diff --git a/tests/warehouse/test_postgres_ops.py b/tests/warehouse/test_postgres_ops.py
index 687ac7b..6a0f0b8 100644
--- a/tests/warehouse/test_postgres_ops.py
+++ b/tests/warehouse/test_postgres_ops.py
@@ -1,8 +1,8 @@
-import pytest
+"""tests for dbt operations using a postgres warehouse"""
 import os
 from pathlib import Path
 import math
-import subprocess, sys
+import subprocess
 from logging import basicConfig, getLogger, INFO
 from dbt_automation.operations.droprenamecolumns import rename_columns, drop_columns
 from dbt_automation.utils.warehouseclient import get_client
@@ -43,6 +43,7 @@ class TestPostgresOperations:
 
     @staticmethod
     def execute_dbt(cmd: str, select_model: str = None):
+        """runs a dbt command"""
        try:
             select_cli = ["--select", select_model] if select_model is not None else []
             subprocess.check_call(
@@ -62,8 +63,12 @@ def execute_dbt(cmd: str, select_model: str = None):
                 ],
             )
         except subprocess.CalledProcessError as e:
-            logger.error(f"dbt {cmd} failed with {e.returncode}")
-            raise Exception(f"Something went wrong while running dbt {cmd}")
+            logger.error(  # pylint:disable=logging-fstring-interpolation
+                f"dbt {cmd} failed with {e.returncode}"
+            )
+            raise Exception(  # pylint:disable=logging-fstring-interpolation, broad-exception-raised, raise-missing-from
+                f"Something went wrong while running dbt {cmd}"
+            )
 
     def test_scaffold(self, tmpdir):
         """This will setup the dbt repo to run dbt commands after running a test operation"""
@@ -114,7 +119,7 @@ def test_flatten(self):
         TestPostgresOperations.execute_dbt("run", "Sheet1")
         TestPostgresOperations.execute_dbt("run", "Sheet2")
         logger.info("inside test flatten")
-        logger.info(
+        logger.info(  # pylint:disable=logging-fstring-interpolation
             f"inside project directory : {TestPostgresOperations.test_project_dir}"
         )
         assert "Sheet1" in TestPostgresOperations.wc_client.get_tables(
@@ -288,8 +293,8 @@ def test_castdatatypes(self):
         assert "measure2" in cols
         table_data = wc_client.get_table_data("pytest_intermediate", output_name, 1)
         # TODO: do stronger check here; fetch datatype from warehouse and then compare/assert
-        assert type(table_data[0]["measure1"]) == int
-        assert type(table_data[0]["measure2"]) == int
+        assert isinstance(table_data[0]["measure1"], int)
+        assert isinstance(table_data[0]["measure2"], int)
 
     def test_arithmetic_add(self):
         """test arithmetic addition"""
@@ -409,11 +414,11 @@ def test_arithmetic_div(self):
         assert (
             math.ceil(table_data[0]["measure1"] / table_data[0]["measure2"])
             if table_data[0]["measure2"] != 0
-            else None
+            else "div-0"
             == (
                 math.ceil(table_data[0]["div_col"])
                 if table_data[0]["div_col"] is not None
-                else None
+                else "div-0"
             )
         )
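
Usage sketch (editor's addition, not part of the diff): one way the new explodejsonarray operation could be called directly, using only the config keys it reads in dbt_automation/operations/explodejsonarray.py. The env-var-based conn_info mirrors the postgres keys visible in seeds/seed.py; the "username"/"database" key names and all schema, table, and model values below are assumptions for illustration.

    import os
    from dbt_automation.utils.warehouseclient import get_client
    from dbt_automation.operations.explodejsonarray import explodejsonarray

    # postgres connection details read from env vars, following seeds/seed.py;
    # "username" and "database" key names are assumed, the others appear in the diff
    conn_info = {
        "host": os.getenv("TEST_PG_DBHOST"),
        "port": os.getenv("TEST_PG_DBPORT"),
        "username": os.getenv("TEST_PG_DBUSER"),      # assumed env var name
        "database": os.getenv("TEST_PG_DBNAME"),      # assumed env var name
        "password": os.getenv("TEST_PG_DBPASSWORD"),
    }
    warehouse = get_client("postgres", conn_info)

    # the keys are exactly those explodejsonarray() reads; every value is hypothetical
    config = {
        "source_schema": "staging",              # schema holding the raw table
        "source_table": "_airbyte_raw_Sheet3",   # table whose json array column is inspected
        "input_source": "_airbyte_raw_Sheet3",   # or set "input_model" to ref another dbt model
        "input_model": None,
        "dest_schema": "intermediate",           # dbt schema for the generated model
        "output_model": "sheet3_exploded",       # name of the model to write
        "columns_to_copy": ["_airbyte_ab_id"],
        "json_column": "_airbyte_data",
    }

    # writes the generated model SQL into the dbt project via dbtProject.write_model
    explodejsonarray(config, warehouse, project_dir="/path/to/dbt/project")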