
Commit

Merge pull request #9 from kameshsampath/de-snowpark-py-update
(fix): Simplify and Context
iamontheinet authored Jun 26, 2024
2 parents 648ed20 + 60ab3a4 commit d6784e8
Showing 5 changed files with 125 additions and 62 deletions.
65 changes: 46 additions & 19 deletions app/05_raw_data.py
@@ -1,70 +1,97 @@
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Hands-On Lab: Data Engineering with Snowpark
# Script: 02_load_raw.py
# Author: Jeremiah Hansen, Caleb Baechtold
# Last Updated: 1/9/2023
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------

import time
from snowflake.snowpark import Session


POS_TABLES = ['country', 'franchise', 'location', 'menu', 'truck', 'order_header', 'order_detail']
CUSTOMER_TABLES = ['customer_loyalty']
POS_TABLES = [
"country",
"franchise",
"location",
"menu",
"truck",
"order_header",
"order_detail",
]
CUSTOMER_TABLES = ["customer_loyalty"]
TABLE_DICT = {
"pos": {"schema": "RAW_POS", "tables": POS_TABLES},
"customer": {"schema": "RAW_CUSTOMER", "tables": CUSTOMER_TABLES}
"customer": {"schema": "RAW_CUSTOMER", "tables": CUSTOMER_TABLES},
}

# SNOWFLAKE ADVANTAGE: Schema detection
# SNOWFLAKE ADVANTAGE: Data ingestion with COPY
# SNOWFLAKE ADVANTAGE: Snowflake Tables (not file-based)


def load_raw_table(session, tname=None, s3dir=None, year=None, schema=None):
session.use_schema(schema)
if year is None:
location = "@external.frostbyte_raw_stage/{}/{}".format(s3dir, tname)
else:
print('\tLoading year {}'.format(year))
location = "@external.frostbyte_raw_stage/{}/{}/year={}".format(s3dir, tname, year)

print("\tLoading year {}".format(year))
location = "@external.frostbyte_raw_stage/{}/{}/year={}".format(
s3dir, tname, year
)

# we can infer schema using the parquet read option
df = session.read.option("compression", "snappy") \
.parquet(location)
df = session.read.option("compression", "snappy").parquet(location)
df.copy_into_table("{}".format(tname))


# SNOWFLAKE ADVANTAGE: Warehouse elasticity (dynamic scaling)


def load_all_raw_tables(session):
_ = session.sql("ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XLARGE WAIT_FOR_COMPLETION = TRUE").collect()
_ = session.sql(
"ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XLARGE WAIT_FOR_COMPLETION = TRUE"
).collect()

for s3dir, data in TABLE_DICT.items():
tnames = data['tables']
schema = data['schema']
tnames = data["tables"]
schema = data["schema"]
for tname in tnames:
print("Loading {}".format(tname))
# Only load the first 3 years of data for the order tables at this point
# We will load the 2022 data later in the lab
if tname in ['order_header', 'order_detail']:
for year in ['2019', '2020', '2021']:
load_raw_table(session, tname=tname, s3dir=s3dir, year=year, schema=schema)
if tname in ["order_header", "order_detail"]:
for year in ["2019", "2020", "2021"]:
load_raw_table(
session, tname=tname, s3dir=s3dir, year=year, schema=schema
)
else:
load_raw_table(session, tname=tname, s3dir=s3dir, schema=schema)

_ = session.sql("ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XSMALL").collect()


def validate_raw_tables(session):
# check column names from the inferred schema
for tname in POS_TABLES:
print('{}: \n\t{}\n'.format(tname, session.table('RAW_POS.{}'.format(tname)).columns))
print(
"{}: \n\t{}\n".format(
tname, session.table("RAW_POS.{}".format(tname)).columns
)
)

for tname in CUSTOMER_TABLES:
print('{}: \n\t{}\n'.format(tname, session.table('RAW_CUSTOMER.{}'.format(tname)).columns))
print(
"{}: \n\t{}\n".format(
tname, session.table("RAW_CUSTOMER.{}".format(tname)).columns
)
)


# For local debugging
if __name__ == "__main__":
# Create a local Snowpark session
with Session.builder.getOrCreate() as session:
# Set the right database context to use
session.use_database("HOL_DB")
load_all_raw_tables(session)
validate_raw_tables(session)
validate_raw_tables(session)
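
For reference, a minimal sketch of the stage-to-table load pattern this file relies on: Snowpark infers the schema from snappy-compressed Parquet on a stage, then loads it with COPY via copy_into_table(). The connection setup, stage path, and table name below are illustrative placeholders taken from the lab context, not a prescribed implementation.

from snowflake.snowpark import Session

def load_one_raw_table(session: Session, stage_path: str, table_name: str) -> None:
    # Infer the schema from the Parquet files on the stage, then COPY into the table.
    df = session.read.option("compression", "snappy").parquet(stage_path)
    df.copy_into_table(table_name)

if __name__ == "__main__":
    # Assumes connection parameters are already configured for Session.builder.
    with Session.builder.getOrCreate() as session:
        session.use_database("HOL_DB")
        session.use_schema("RAW_POS")
        load_one_raw_table(session, "@external.frostbyte_raw_stage/pos/country", "country")
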
79 changes: 56 additions & 23 deletions app/06_load_daily_city_metrics.py
@@ -1,12 +1,20 @@
from snowflake.snowpark import Session
import snowflake.snowpark.functions as F

def table_exists(session, schema='', name=''):
exists = session.sql("SELECT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{}' AND TABLE_NAME = '{}') AS TABLE_EXISTS".format(schema, name)).collect()[0]['TABLE_EXISTS']

def table_exists(session, schema="", name=""):
exists = session.sql(
"SELECT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{}' AND TABLE_NAME = '{}') AS TABLE_EXISTS".format(
schema, name
)
).collect()[0]["TABLE_EXISTS"]
return exists


def main(session: Session) -> str:
_ = session.sql('ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XLARGE WAIT_FOR_COMPLETION = TRUE').collect()
_ = session.sql(
"ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XLARGE WAIT_FOR_COMPLETION = TRUE"
).collect()
schema_name = "HOL_SCHEMA"
table_name = "DAILY_CITY_METRICS"

@@ -17,41 +25,66 @@ def main(session: Session) -> str:
location = session.table("RAW_POS.LOCATION")

# Join the tables
orders = order_header.join(order_detail, order_header['ORDER_ID'] == order_detail['ORDER_ID'])
orders = orders.join(location, orders['LOCATION_ID'] == location['LOCATION_ID'])
order_detail = orders.join(history_day, (F.builtin("DATE")(order_header['ORDER_TS']) == history_day['DATE_VALID_STD']) & (orders['ISO_COUNTRY_CODE'] == history_day['COUNTRY']) & (orders['CITY'] == history_day['CITY_NAME']))
orders = order_header.join(
order_detail, order_header["ORDER_ID"] == order_detail["ORDER_ID"]
)
orders = orders.join(location, orders["LOCATION_ID"] == location["LOCATION_ID"])
order_detail = orders.join(
history_day,
(F.builtin("DATE")(order_header["ORDER_TS"]) == history_day["DATE_VALID_STD"])
& (orders["ISO_COUNTRY_CODE"] == history_day["COUNTRY"])
& (orders["CITY"] == history_day["CITY_NAME"]),
)

# Aggregate the data
final_agg = order_detail.group_by(F.col('DATE_VALID_STD'), F.col('CITY_NAME'), F.col('ISO_COUNTRY_CODE')) \
.agg( \
F.sum('PRICE').alias('DAILY_SALES_SUM'), \
F.avg('AVG_TEMPERATURE_AIR_2M_F').alias("AVG_TEMPERATURE_F"), \
F.avg("TOT_PRECIPITATION_IN").alias("AVG_PRECIPITATION_IN"), \
) \
.select(F.col("DATE_VALID_STD").alias("DATE"), F.col("CITY_NAME"), F.col("ISO_COUNTRY_CODE").alias("COUNTRY_DESC"), \
F.builtin("ZEROIFNULL")(F.col("DAILY_SALES_SUM")).alias("DAILY_SALES"), \
F.round(F.col("AVG_TEMPERATURE_F"), 2).alias("AVG_TEMPERATURE_FAHRENHEIT"), \
F.round(F.col("AVG_PRECIPITATION_IN"), 2).alias("AVG_PRECIPITATION_INCHES"), \
)
final_agg = (
order_detail.group_by(
F.col("DATE_VALID_STD"), F.col("CITY_NAME"), F.col("ISO_COUNTRY_CODE")
)
.agg(
F.sum("PRICE").alias("DAILY_SALES_SUM"),
F.avg("AVG_TEMPERATURE_AIR_2M_F").alias("AVG_TEMPERATURE_F"),
F.avg("TOT_PRECIPITATION_IN").alias("AVG_PRECIPITATION_IN"),
)
.select(
F.col("DATE_VALID_STD").alias("DATE"),
F.col("CITY_NAME"),
F.col("ISO_COUNTRY_CODE").alias("COUNTRY_DESC"),
F.builtin("ZEROIFNULL")(F.col("DAILY_SALES_SUM")).alias("DAILY_SALES"),
F.round(F.col("AVG_TEMPERATURE_F"), 2).alias("AVG_TEMPERATURE_FAHRENHEIT"),
F.round(F.col("AVG_PRECIPITATION_IN"), 2).alias("AVG_PRECIPITATION_INCHES"),
)
)

session.use_schema(schema_name)
# If the table doesn't exist then create it
if not table_exists(session, schema=schema_name, name=table_name):
final_agg.write.mode("overwrite").save_as_table(table_name)
_ = session.sql('ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XSMALL').collect()
_ = session.sql("ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XSMALL").collect()
return f"Successfully created {table_name}"
# Otherwise update it
else:
cols_to_update = {c: final_agg[c] for c in final_agg.schema.names}

dcm = session.table(table_name)
dcm.merge(final_agg, (dcm['DATE'] == final_agg['DATE']) & (dcm['CITY_NAME'] == final_agg['CITY_NAME']) & (dcm['COUNTRY_DESC'] == final_agg['COUNTRY_DESC']), \
[F.when_matched().update(cols_to_update), F.when_not_matched().insert(cols_to_update)])
_ = session.sql('ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XSMALL').collect()
dcm.merge(
final_agg,
(dcm["DATE"] == final_agg["DATE"])
& (dcm["CITY_NAME"] == final_agg["CITY_NAME"])
& (dcm["COUNTRY_DESC"] == final_agg["COUNTRY_DESC"]),
[
F.when_matched().update(cols_to_update),
F.when_not_matched().insert(cols_to_update),
],
)
_ = session.sql("ALTER WAREHOUSE HOL_WH SET WAREHOUSE_SIZE = XSMALL").collect()
return f"Successfully updated {table_name}"



# For local debugging
if __name__ == "__main__":
# Create a local Snowpark session
with Session.builder.getOrCreate() as session:
main(session)
# Set the right database context to use
session.use_database("HOL_DB")
main(session)
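
As a compact aside, the upsert at the heart of this file reduces to the following sketch of Snowpark's merge API with when_matched()/when_not_matched() clauses keyed on date, city, and country; the target table and source DataFrame are assumed to already exist with matching column names.

from snowflake.snowpark import Session, DataFrame
import snowflake.snowpark.functions as F

def upsert_metrics(session: Session, final_agg: DataFrame, table_name: str) -> None:
    # Update rows that match on the three keys; insert rows that do not.
    target = session.table(table_name)
    cols_to_update = {c: final_agg[c] for c in final_agg.schema.names}
    target.merge(
        final_agg,
        (target["DATE"] == final_agg["DATE"])
        & (target["CITY_NAME"] == final_agg["CITY_NAME"])
        & (target["COUNTRY_DESC"] == final_agg["COUNTRY_DESC"]),
        [
            F.when_matched().update(cols_to_update),
            F.when_not_matched().insert(cols_to_update),
        ],
    )
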
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1 +1,2 @@
snowflake-snowpark-python
snowflake-snowpark-python
snowflake
14 changes: 1 addition & 13 deletions steps/03_git_config.sql
@@ -17,10 +17,7 @@ GRANT OWNERSHIP ON SCHEMA PUBLIC TO ROLE GIT_ADMIN;
USE ROLE GIT_ADMIN;
USE DATABASE GIT_REPO;
USE SCHEMA PUBLIC;
CREATE OR REPLACE SECRET GIT_SECRET
TYPE = PASSWORD
USERNAME = '<your_git_user>'
PASSWORD = '<your_personal_access_token>';


--Create an API integration for interacting with the repository API
USE ROLE ACCOUNTADMIN;
@@ -30,20 +27,11 @@ USE ROLE GIT_ADMIN;
CREATE OR REPLACE API INTEGRATION GIT_API_INTEGRATION
API_PROVIDER = GIT_HTTPS_API
API_ALLOWED_PREFIXES = ('https://github.com/<your_git_user>')
ALLOWED_AUTHENTICATION_SECRETS = (GIT_SECRET)
ENABLED = TRUE;

CREATE OR REPLACE GIT REPOSITORY DE_QUICKSTART
API_INTEGRATION = GIT_API_INTEGRATION
GIT_CREDENTIALS = GIT_SECRET
ORIGIN = '<your git repo URL ending in .git>';

SHOW GIT BRANCHES IN DE_QUICKSTART;
ls @DE_QUICKSTART/branches/main;

USE ROLE ACCOUNTADMIN;
SET MY_USER = CURRENT_USER();
EXECUTE IMMEDIATE
FROM @GIT_REPO.PUBLIC.DE_QUICKSTART/branches/main/steps/03_setup_snowflake.sql
USING (MY_USER=>$MY_USER);

26 changes: 20 additions & 6 deletions steps/07_deploy_task_dag.py
@@ -1,9 +1,9 @@
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Hands-On Lab: Intro to Data Engineering with Snowpark Python
# Script: 07_deploy_task_dag.py
# Author: Jeremiah Hansen
# Last Updated: 9/26/2023
#------------------------------------------------------------------------------
# ------------------------------------------------------------------------------

# SNOWFLAKE ADVANTAGE: Snowpark Python API
# SNOWFLAKE ADVANTAGE: Snowpark Python Task DAG API
@@ -16,12 +16,18 @@
from snowflake.core.task import StoredProcedureCall, Task
from snowflake.core.task.dagv1 import DAGOperation, DAG, DAGTask


# Create the tasks using the DAG API
def main(session: Session) -> str:
database_name = "HOL_DB"
schema_name = "HOL_SCHEMA"
warehouse_name = "HOL_WH"

# set database context
session.use_database(database_name)
# set database schema context
session.use_schema(schema_name)

api_root = Root(session)
schema = api_root.databases[database_name].schemas[schema_name]
tasks = schema.tasks
@@ -30,16 +36,24 @@ def main(session: Session) -> str:
dag_name = "HOL_DAG"
dag = DAG(dag_name, schedule=timedelta(days=1), warehouse=warehouse_name)
with dag:
dag_task1 = DAGTask("LOAD_ORDER_DETAIL_TASK", definition="CALL LOAD_EXCEL_WORKSHEET_TO_TABLE_SP(BUILD_SCOPED_FILE_URL(@FROSTBYTE_RAW_STAGE, 'intro/order_detail.xlsx'), 'order_detail', 'ORDER_DETAIL')", warehouse=warehouse_name)
dag_task2 = DAGTask("LOAD_DAILY_CITY_METRICS_TASK", definition="CALL LOAD_DAILY_CITY_METRICS_SP()", warehouse=warehouse_name)
dag_task1 = DAGTask(
"LOAD_ORDER_DETAIL_TASK",
definition="CALL LOAD_EXCEL_WORKSHEET_TO_TABLE_SP(BUILD_SCOPED_FILE_URL(@EXTERNAL.FROSTBYTE_RAW_STAGE, 'intro/order_detail.xlsx'), 'order_detail', 'ORDER_DETAIL')",
warehouse=warehouse_name,
)
dag_task2 = DAGTask(
"LOAD_DAILY_CITY_METRICS_TASK",
definition="CALL LOAD_DAILY_CITY_METRICS_SP()",
warehouse=warehouse_name,
)

dag_task2 >> dag_task1

# Create the DAG in Snowflake
dag_op = DAGOperation(schema)
dag_op.deploy(dag, mode="orreplace")

dagiter = dag_op.iter_dags(like='hol_dag%')
dagiter = dag_op.iter_dags(like="hol_dag%")
for dag_name in dagiter:
print(dag_name)

@@ -50,6 +64,6 @@ def main(session: Session) -> str:

# For local debugging
# Be aware you may need to type-convert arguments if you add input parameters
if __name__ == '__main__':
if __name__ == "__main__":
with Session.builder.getOrCreate() as session:
main(session)
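
For orientation, a minimal sketch of the Task DAG pattern used above: tasks are declared inside a DAG context, chained with >>, and deployed through DAGOperation. The DAG name, task names, and SELECT statements below are placeholders, not part of the lab.

from datetime import timedelta
from snowflake.snowpark import Session
from snowflake.core import Root
from snowflake.core.task.dagv1 import DAG, DAGTask, DAGOperation

def deploy_example_dag(session: Session) -> None:
    schema = Root(session).databases["HOL_DB"].schemas["HOL_SCHEMA"]
    dag = DAG("EXAMPLE_DAG", schedule=timedelta(days=1), warehouse="HOL_WH")
    with dag:
        task_a = DAGTask("TASK_A", definition="SELECT 1", warehouse="HOL_WH")
        task_b = DAGTask("TASK_B", definition="SELECT 2", warehouse="HOL_WH")
        task_a >> task_b  # declare the dependency edge between the two tasks
    # Create or replace the corresponding Snowflake tasks in the target schema.
    DAGOperation(schema).deploy(dag, mode="orreplace")

if __name__ == "__main__":
    with Session.builder.getOrCreate() as session:
        deploy_example_dag(session)
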
