45 explode json array operation #46

Open · wants to merge 12 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -28,6 +28,7 @@ MANIFEST
*.json
!sample_sheet1.json
!sample_sheet2.json
!sample_sheet3.json

# Unit test / coverage reports
htmlcov/
102 changes: 102 additions & 0 deletions dbt_automation/operations/explodejsonarray.py
@@ -0,0 +1,102 @@
"""explode elements out of a json list into their own rows"""
from logging import basicConfig, getLogger, INFO

from dbt_automation.utils.dbtproject import dbtProject
from dbt_automation.utils.columnutils import quote_columnname
from dbt_automation.utils.columnutils import make_cleaned_column_names, dedup_list

basicConfig(level=INFO)
logger = getLogger()


# pylint:disable=unused-argument,logging-fstring-interpolation
def explodejsonarray(config: dict, warehouse, project_dir: str):
"""
source_schema: name of the input schema
input_name: name of the input model
dest_schema: name of the output schema
output_model: name of the output model
columns_to_copy: list of columns to copy from the input model
json_column: name of the json column to flatten
"""

source_schema = config["source_schema"]
source_table = config["source_table"]
input_model = config.get("input_model")
input_source = config.get("input_source")
if input_model is None and input_source is None:
raise ValueError("either input_model or input_source must be specified")
dest_schema = config["dest_schema"]
output_model = config["output_model"]
columns_to_copy = config["columns_to_copy"]
json_column = config["json_column"]

model_code = f'{{{{ config(materialized="table", schema="{dest_schema}") }}}}'
model_code += "\n"

if columns_to_copy is None:
model_code += "SELECT "
elif columns_to_copy == "*":
model_code += "SELECT *, "
else:
select_list = [quote_columnname(col, warehouse.name) for col in columns_to_copy]
model_code += f"SELECT {', '.join(select_list)}, "

model_code += "\n"

json_columns = warehouse.get_json_columnspec_from_array(
source_schema, source_table, json_column
)

# convert to sql-friendly column names
sql_columns = make_cleaned_column_names(json_columns)

# after cleaning we may have duplicates
sql_columns = dedup_list(sql_columns)

if warehouse.name == "postgres":
select_list = []
for json_field, sql_column in zip(json_columns, sql_columns):
select_list.append(
warehouse.json_extract_from_array_op(
json_column, json_field, sql_column
)
)
model_code += ",".join(select_list)
model_code += "\nFROM\n"

if input_model:
model_code += "{{ref('" + input_model + "')}}"
else:
model_code += "{{source('" + source_schema + "', '" + input_source + "')}}"

model_code += "\n"

elif warehouse.name == "bigquery":
select_list = []
for json_field, sql_column in zip(json_columns, sql_columns):
select_list.append(
warehouse.json_extract_op("JVAL", json_field, sql_column)
)
model_code += ",".join(select_list)
model_code += "\nFROM\n"

if input_model:
model_code += "{{ref('" + input_model + "')}}"
else:
model_code += "{{source('" + source_schema + "', '" + input_source + "')}}"

model_code += f""" CROSS JOIN UNNEST((
SELECT JSON_EXTRACT_ARRAY(`{json_column}`, '$')
FROM """

if input_model:
model_code += "{{ref('" + input_model + "')}}"
else:
model_code += "{{source('" + source_schema + "', '" + input_source + "')}}"

model_code += ")) `JVAL`"

dbtproject = dbtProject(project_dir)
dbtproject.ensure_models_dir(dest_schema)
dbtproject.write_model(dest_schema, output_model, model_code)
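
For reference, a minimal sketch of how this operation might be invoked from Python. All schema, table, model, and column names below are hypothetical; only the keys mirror what explodejsonarray() actually reads from config:

from dbt_automation.operations.explodejsonarray import explodejsonarray

# Hypothetical values; only the keys match what explodejsonarray() reads.
config = {
    "source_schema": "staging",             # schema holding the raw table
    "source_table": "_airbyte_raw_sheet3",  # table containing the json array column
    "input_model": None,                    # supply either input_model ...
    "input_source": "sheet3",               # ... or input_source (at least one is required)
    "dest_schema": "intermediate",
    "output_model": "sheet3_exploded",
    "columns_to_copy": ["_airbyte_ab_id"],  # None, "*", or a list of column names
    "json_column": "_airbyte_data",         # column whose array elements become rows
}

# `warehouse` must be a connected warehouse client exposing .name,
# get_json_columnspec_from_array() and the json_extract_* helpers from this PR;
# `project_dir` points at the dbt project whose models directory is written to.
# explodejsonarray(config, warehouse, project_dir="/path/to/dbt/project")
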
78 changes: 56 additions & 22 deletions dbt_automation/utils/bigquery.py
@@ -1,11 +1,11 @@
"""utilities for working with bigquery"""

from logging import basicConfig, getLogger, INFO
import os
import json
from logging import basicConfig, getLogger, INFO
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from google.oauth2 import service_account
import json

basicConfig(level=INFO)
logger = getLogger()
@@ -18,8 +18,10 @@ def __init__(self, conn_info=None, location=None):
self.name = "bigquery"
self.bqclient = None
if conn_info is None: # take creds from env
creds_file = open(os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
conn_info = json.load(creds_file)
with open(
os.getenv("GOOGLE_APPLICATION_CREDENTIALS"), "r", encoding="utf-8"
) as creds_file:
conn_info = json.load(creds_file)
location = os.getenv("BIQUERY_LOCATION")

creds1 = service_account.Credentials.from_service_account_info(conn_info)
@@ -69,10 +71,10 @@ def get_json_columnspec(
query = self.execute(
f'''
CREATE TEMP FUNCTION jsonObjectKeys(input STRING)
RETURNS Array<String>
LANGUAGE js AS """
return Object.keys(JSON.parse(input));
""";
RETURNS Array<String>
LANGUAGE js AS """
return Object.keys(JSON.parse(input));
""";
WITH keys AS (
SELECT
jsonObjectKeys({column}) AS keys
@@ -89,6 +91,39 @@ def get_json_columnspec(
)
return [json_field["k"] for json_field in query]

def get_json_columnspec_from_array(self, schema: str, table: str, column: str):
"""get the column schema from the elements of the specified json array for this table"""
query = self.execute(
f'''
CREATE TEMP FUNCTION jsonObjectKeys(input STRING)
RETURNS Array<String>
LANGUAGE js AS """
return Object.keys(JSON.parse(input));
""";

WITH keys as (
WITH key_rows AS (
WITH `json_data` as (
SELECT JSON_EXTRACT_ARRAY(`{column}`, '$')
FROM `{schema}`.`{table}`
)
SELECT * FROM UNNEST(
(SELECT * FROM `json_data`)
) as key
)
SELECT jsonObjectKeys(`key`)
AS key
FROM key_rows
)
SELECT DISTINCT k
FROM keys
CROSS JOIN UNNEST(keys.key)
AS k
''',
location=self.location,
)
return [json_field["k"] for json_field in query]

def schema_exists_(self, schema: str) -> bool:
"""checks if the schema exists"""
try:
@@ -140,7 +175,7 @@ def insert_row(self, schema: str, table: str, row: dict):
def json_extract_op(self, json_column: str, json_field: str, sql_column: str):
"""outputs a sql query snippet for extracting a json field"""
json_field = json_field.replace("'", "\\'")
return f"json_value({json_column}, '$.\"{json_field}\"') as `{sql_column}`"
return f"json_value(`{json_column}`, '$.\"{json_field}\"') as `{sql_column}`"

def close(self):
"""closing the connection and releasing system resources"""
@@ -152,25 +187,25 @@ def close(self):
return True

def generate_profiles_yaml_dbt(self, project_name, default_schema):
"""Generates the profiles.yml dictionary object for dbt"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

"""
<project_name>:
Generates the profiles.yml dictionary object for dbt
<project_name>:
outputs:
prod:
keyfile_json:
location:
prod:
keyfile_json:
location:
method: service-account-json
project:
schema:
project:
schema:
threads: 4
type: bigquery
target: prod
"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

profiles_yml = {
f"{project_name}": {
"outputs": {
@@ -179,7 +214,6 @@ def generate_profiles_yaml_dbt(self, project_name, default_schema):
"location": self.location,
"method": "service-account-json",
"project": self.conn_info["project_id"],
"method": "service-account-json",
"schema": default_schema,
"threads": 4,
"type": "bigquery",
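
As a quick sanity check on the backtick fix to json_extract_op above, here is a hedged sketch of the snippet the helper now emits. It assumes `bq` is an instance of this module's BigQuery client; the field and column names are made up:

# json_extract_op now backtick-quotes the source column as well as the alias.
snippet = bq.json_extract_op(json_column="JVAL", json_field="NGO", sql_column="ngo")
print(snippet)
# json_value(`JVAL`, '$."NGO"') as `ngo`
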
49 changes: 34 additions & 15 deletions dbt_automation/utils/postgres.py
@@ -1,7 +1,7 @@
"""helpers for postgres"""
import os
from logging import basicConfig, getLogger, INFO
import psycopg2
import os

basicConfig(level=INFO)
logger = getLogger()
@@ -71,7 +71,7 @@ def get_tables(self, schema: str) -> list:
def get_schemas(self) -> list:
"""returns the list of schema names in the given database connection"""
resultset = self.execute(
f"""
"""
SELECT nspname
FROM pg_namespace
WHERE nspname NOT LIKE 'pg_%' AND nspname != 'information_schema';
@@ -130,6 +130,18 @@ def get_json_columnspec(self, schema: str, table: str, column: str):
)
]

def get_json_columnspec_from_array(self, schema: str, table: str, column: str):
"""get the column schema from the elements of the specified json array for this table"""
return [
x[0]
for x in self.execute(
f"""SELECT DISTINCT
jsonb_object_keys(jsonb_array_elements({column}::jsonb)) AS key
FROM "{schema}"."{table}"
"""
)
]

def ensure_schema(self, schema: str):
"""creates the schema if it doesn't exist"""
self.runcmd(f"CREATE SCHEMA IF NOT EXISTS {schema};")
@@ -164,35 +176,42 @@ def json_extract_op(self, json_column: str, json_field: str, sql_column: str):
"""outputs a sql query snippet for extracting a json field"""
return f"{json_column}::json->>'{json_field}' as \"{sql_column}\""

def json_extract_from_array_op(
self, json_column: str, json_field: str, sql_column: str
):
"""outputs a sql query snippet for extracting a json field from elements of a list into several rows"""
return f"jsonb_array_elements({json_column}::jsonb)->>'{json_field}' as \"{sql_column}\""

def close(self):
"""closes the connection"""
try:
self.connection.close()
except Exception:
except Exception: # pylint:disable=broad-except
logger.error("something went wrong while closing the postgres connection")

return True

def generate_profiles_yaml_dbt(self, project_name, default_schema):
"""Generates the profiles.yml dictionary object for dbt"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

"""
<project_name>:
Generates the profiles.yml dictionary object for dbt
<project_name>:
outputs:
prod:
dbname:
host:
password:
prod:
dbname:
host:
password:
port: 5432
user: airbyte_user
schema:
schema:
threads: 4
type: postgres
target: prod
"""
if project_name is None or default_schema is None:
raise ValueError("project_name and default_schema are required")

target = "prod"

profiles_yml = {
f"{project_name}": {
"outputs": {
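
A similar sketch for the new Postgres helper, assuming `pg` is an instance of this module's client class and using made-up names; each matching source row fans out into one output row per array element:

snippet = pg.json_extract_from_array_op(
    json_column="_airbyte_data", json_field="NGO", sql_column="ngo"
)
print(snippet)
# jsonb_array_elements(_airbyte_data::jsonb)->>'NGO' as "ngo"
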
2 changes: 2 additions & 0 deletions scripts/main.py
@@ -20,6 +20,7 @@
from dbt_automation.operations.flattenjson import flattenjson
from dbt_automation.operations.regexextraction import regex_extraction
from dbt_automation.operations.scaffold import scaffold
from dbt_automation.operations.explodejsonarray import explodejsonarray

OPERATIONS_DICT = {
"flatten": flatten_operation,
@@ -34,6 +35,7 @@
"renamecolumns": rename_columns,
"regexextraction": regex_extraction,
"scaffold": scaffold,
"explodejsonarray": explodejsonarray,
}

load_dotenv("./../dbconnection.env")
7 changes: 7 additions & 0 deletions seeds/sample_sheet3.json
@@ -0,0 +1,7 @@
[
{
"_airbyte_ab_id": "006b18b2-cccd-47f1-a9dc-5638d2d1abc7",
"_airbyte_data": "{\"data\": [{\"NGO\": \"IMAGE\", \"SPOC\": \"SPOC C\"}, {\"NGO\": \"FDSR\", \"SPOC\": \"SPOC B\"}, {\"NGO\": \"CRC\", \"SPOC\": \"SPOC A\"}]}",
"_airbyte_emitted_at": "2023-09-28 08:50:17+00"
}
]
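
A plain-Python illustration of this seed's structure (not the warehouse code path): the json array sits under the "data" key of the serialized _airbyte_data payload, each element is destined to become one output row, and the union of element keys becomes the candidate column set:

import json

# One record from seeds/sample_sheet3.json, values abbreviated.
record = {
    "_airbyte_ab_id": "006b18b2-cccd-47f1-a9dc-5638d2d1abc7",
    "_airbyte_data": '{"data": [{"NGO": "IMAGE", "SPOC": "SPOC C"}, {"NGO": "FDSR", "SPOC": "SPOC B"}]}',
}

elements = json.loads(record["_airbyte_data"])["data"]
columns = sorted({key for element in elements for key in element})
print(columns)   # ['NGO', 'SPOC']
print(elements)  # one dict per future row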