Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snowflake IO Manager handle pandas timestamps #8760

Merged
merged 24 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime
from typing import Mapping, Union, cast

import pandas as pd
from dagster_snowflake import DbTypeHandler
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake.snowflake_io_manager import SnowflakeDbClient, TableSlice
from pandas import DataFrame, read_sql
from snowflake.connector.pandas_tools import pd_writer

from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema
Expand All @@ -21,7 +22,25 @@ def _connect_snowflake(context: Union[InputContext, OutputContext], table_slice:
).get_connection(raw_conn=False)


class SnowflakePandasTypeHandler(DbTypeHandler[DataFrame]):
def _convert_timestamp_to_date(s: pd.Series) -> pd.Series:
jamiedemaria marked this conversation as resolved.
Show resolved Hide resolved
"""
Converts columns of data of type pd.Timezone to datetime.date so that it can be stored in
snowflake
"""
if pd.core.dtypes.common.is_datetime_or_timedelta_dtype(s):
return s.dt.to_pydatetime()
else:
return s


def _convert_date_to_timestamp(s: pd.Series) -> pd.Series:
jamiedemaria marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(s[0], datetime.date):
return pd.to_datetime(s.values)
else:
return s


class SnowflakePandasTypeHandler(DbTypeHandler[pd.DataFrame]):
"""
Defines how to translate between slices of Snowflake tables and Pandas DataFrames.

Expand All @@ -40,13 +59,14 @@ def my_job():
"""

def handle_output(
self, context: OutputContext, table_slice: TableSlice, obj: DataFrame
self, context: OutputContext, table_slice: TableSlice, obj: pd.DataFrame
) -> Mapping[str, RawMetadataValue]:
from snowflake import connector # pylint: disable=no-name-in-module

connector.paramstyle = "pyformat"
with _connect_snowflake(context, table_slice) as con:
with_uppercase_cols = obj.rename(str.upper, copy=False, axis="columns")
with_uppercase_cols = with_uppercase_cols.apply(_convert_timestamp_to_date, axis=0)
jamiedemaria marked this conversation as resolved.
Show resolved Hide resolved
with_uppercase_cols.to_sql(
table_slice.table,
con=con.engine,
Expand All @@ -67,12 +87,13 @@ def handle_output(
),
}

def load_input(self, context: InputContext, table_slice: TableSlice) -> DataFrame:
def load_input(self, context: InputContext, table_slice: TableSlice) -> pd.DataFrame:
with _connect_snowflake(context, table_slice) as con:
result = read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result = pd.read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result = result.apply(_convert_date_to_timestamp, axis=0)
jamiedemaria marked this conversation as resolved.
Show resolved Hide resolved
result.columns = map(str.lower, result.columns)
return result

@property
def supported_types(self):
return [DataFrame]
return [pd.DataFrame]
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake.snowflake_io_manager import TableSlice
from dagster_snowflake_pandas import SnowflakePandasTypeHandler
from dagster_snowflake_pandas.snowflake_pandas_type_handler import (
_convert_date_to_timestamp,
_convert_timestamp_to_date,
)
from pandas import DataFrame

from dagster import (
Expand Down Expand Up @@ -84,7 +88,7 @@ def test_handle_output():

def test_load_input():
with patch("dagster_snowflake_pandas.snowflake_pandas_type_handler._connect_snowflake"), patch(
"dagster_snowflake_pandas.snowflake_pandas_type_handler.read_sql"
"dagster_snowflake_pandas.snowflake_pandas_type_handler.pd.read_sql"
) as mock_read_sql:
mock_read_sql.return_value = DataFrame([{"COL1": "a", "COL2": 1}])

Expand All @@ -104,6 +108,28 @@ def test_load_input():
assert df.equals(DataFrame([{"col1": "a", "col2": 1}]))


def test_type_conversions():
# no timestamp data
no_time = pandas.Series([1, 2, 3, 4, 5])
converted = _convert_date_to_timestamp(_convert_timestamp_to_date(no_time))

assert (converted == no_time).all()

# timestamp data
with_time = pandas.Series(
[
pandas.Timestamp("2017-01-01T12"),
pandas.Timestamp("2017-02-01T12"),
pandas.Timestamp("2017-03-01T12"),
]
)
time_converted = _convert_date_to_timestamp(
pandas.Series(_convert_timestamp_to_date(with_time))
)

assert (with_time == time_converted).all()


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
def test_io_manager_with_snowflake_pandas():
with temporary_snowflake_table(
Expand All @@ -128,6 +154,28 @@ def read_pandas_df(df: pandas.DataFrame):
assert set(df.columns) == {"foo", "quux"}
assert len(df.index) == 2

time_df = pandas.DataFrame(
{
"foo": ["bar", "baz"],
"date": [pandas.Timestamp("2017-01-01T12"), pandas.Timestamp("2017-02-01T12")],
jamiedemaria marked this conversation as resolved.
Show resolved Hide resolved
}
)

@op(
out={
table_name: Out(
io_manager_key="snowflake", metadata={"schema": "SNOWFLAKE_IO_MANAGER_SCHEMA"}
)
}
)
def emit_time_df(_):
return time_df

@op
def read_time_df(df: pandas.DataFrame):
assert set(df.columns) == {"foo", "date"}
assert (df == time_df).all()

snowflake_io_manager = build_snowflake_io_manager([SnowflakePandasTypeHandler()])

@job(
Expand All @@ -145,6 +193,7 @@ def read_pandas_df(df: pandas.DataFrame):
)
def io_manager_test_pipeline():
read_pandas_df(emit_pandas_df())
read_time_df(emit_time_df())

res = io_manager_test_pipeline.execute_in_process()
assert res.success