Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snowflake IO Manager handle pandas timestamps #8760

Merged
merged 24 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Mapping, Union, cast

import pandas as pd
from dagster_snowflake import DbTypeHandler
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake.snowflake_io_manager import SnowflakeDbClient, TableSlice
from pandas import DataFrame, read_sql
from snowflake.connector.pandas_tools import pd_writer

from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema
Expand All @@ -21,7 +21,46 @@ def _connect_snowflake(context: Union[InputContext, OutputContext], table_slice:
).get_connection(raw_conn=False)


class SnowflakePandasTypeHandler(DbTypeHandler[DataFrame]):
def _convert_timestamp_to_string(s: pd.Series) -> pd.Series:
"""
Converts columns of data of type pd.Timestamp to string so that it can be stored in
snowflake
"""
if pd.core.dtypes.common.is_datetime_or_timedelta_dtype(s):
# return s.dt.strftime("%Y-%m-%d %H:%M:%S.%f %z")
return None
else:
return s


def _get_timestamp_data(s: pd.Series) -> pd.Series:
"""
Converts columns of data of type pd.Timestamp to string so that it can be stored in
snowflake
"""
if pd.core.dtypes.common.is_datetime_or_timedelta_dtype(s):
# return s.dt.strftime("%Y-%m-%d %H:%M:%S.%f %z")
return s.dt.to_pydatetime()


def _convert_string_to_timestamp(s: pd.Series) -> pd.Series:
"""
Converts columns of strings in Timestamp format to pd.Timestamp to undo the conversion in
_convert_timestamp_to_string

This will not convert non-timestamp strings into timestamps (pd.to_datetime with raise an
jamiedemaria marked this conversation as resolved.
Show resolved Hide resolved
exception if the string cannot be converted)
"""
if isinstance(s[0], str):
try:
return pd.to_datetime(s.values)
except ValueError:
return s
else:
return s


class SnowflakePandasTypeHandler(DbTypeHandler[pd.DataFrame]):
"""
Defines how to translate between slices of Snowflake tables and Pandas DataFrames.

Expand All @@ -40,13 +79,16 @@ def my_job():
"""

def handle_output(
self, context: OutputContext, table_slice: TableSlice, obj: DataFrame
self, context: OutputContext, table_slice: TableSlice, obj: pd.DataFrame
) -> Mapping[str, RawMetadataValue]:
from snowflake import connector # pylint: disable=no-name-in-module

connector.paramstyle = "pyformat"
with _connect_snowflake(context, table_slice) as con:
with_uppercase_cols = obj.rename(str.upper, copy=False, axis="columns")
with_uppercase_cols = with_uppercase_cols.apply(
_convert_timestamp_to_string, axis="index"
)
with_uppercase_cols.to_sql(
table_slice.table,
con=con.engine,
Expand All @@ -55,6 +97,18 @@ def handle_output(
method=pd_writer,
)

for c in obj:
if pd.core.dtypes.common.is_datetime_or_timedelta_dtype(c):
converted = c.dt.to_pydatetime()
con.execute(
"INSERT INTO {%s}({%s}) values(%s)",
(
table_slice.table,
"DATE",
converted,
),
)

return {
"row_count": obj.shape[0],
"dataframe_columns": MetadataValue.table_schema(
Expand All @@ -67,12 +121,13 @@ def handle_output(
),
}

def load_input(self, context: InputContext, table_slice: TableSlice) -> DataFrame:
def load_input(self, context: InputContext, table_slice: TableSlice) -> pd.DataFrame:
with _connect_snowflake(context, table_slice) as con:
result = read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result = pd.read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result = result.apply(_convert_string_to_timestamp, axis="index")
result.columns = map(str.lower, result.columns)
return result

@property
def supported_types(self):
return [DataFrame]
return [pd.DataFrame]
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake.snowflake_io_manager import TableSlice
from dagster_snowflake_pandas import SnowflakePandasTypeHandler
from dagster_snowflake_pandas.snowflake_pandas_type_handler import (
_convert_string_to_timestamp,
_convert_timestamp_to_string,
)
from pandas import DataFrame

from dagster import (
Expand Down Expand Up @@ -84,7 +88,7 @@ def test_handle_output():

def test_load_input():
with patch("dagster_snowflake_pandas.snowflake_pandas_type_handler._connect_snowflake"), patch(
"dagster_snowflake_pandas.snowflake_pandas_type_handler.read_sql"
"dagster_snowflake_pandas.snowflake_pandas_type_handler.pd.read_sql"
) as mock_read_sql:
mock_read_sql.return_value = DataFrame([{"COL1": "a", "COL2": 1}])

Expand All @@ -104,6 +108,31 @@ def test_load_input():
assert df.equals(DataFrame([{"col1": "a", "col2": 1}]))


def test_type_conversions():
    # a purely numeric series should survive the string round trip unchanged
    numeric_series = pandas.Series([1, 2, 3, 4, 5])
    round_tripped = _convert_string_to_timestamp(_convert_timestamp_to_string(numeric_series))

    assert (round_tripped == numeric_series).all()

    # timestamps should round-trip through their string representation
    timestamp_series = pandas.Series(
        [
            pandas.Timestamp("2017-01-01T12:30:45.35"),
            pandas.Timestamp("2017-02-01T12:30:45.35"),
            pandas.Timestamp("2017-03-01T12:30:45.35"),
        ]
    )
    round_tripped_times = _convert_string_to_timestamp(
        _convert_timestamp_to_string(timestamp_series)
    )

    assert (timestamp_series == round_tripped_times).all()

    # strings that are not parseable as timestamps must pass through untouched
    non_timestamp_strings = pandas.Series(["not", "a", "timestamp"])

    assert (
        _convert_string_to_timestamp(non_timestamp_strings) == non_timestamp_strings
    ).all()


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
def test_io_manager_with_snowflake_pandas():
with temporary_snowflake_table(
Expand All @@ -128,6 +157,31 @@ def read_pandas_df(df: pandas.DataFrame):
assert set(df.columns) == {"foo", "quux"}
assert len(df.index) == 2

time_df = pandas.DataFrame(
{
"foo": ["bar", "baz"],
"date": [
pandas.Timestamp("2017-01-01T12:30:45.35"),
pandas.Timestamp("2017-02-01T12:30:45.35"),
],
}
)

@op(
out={
table_name: Out(
io_manager_key="snowflake", metadata={"schema": "SNOWFLAKE_IO_MANAGER_SCHEMA"}
)
}
)
def emit_time_df(_):
return time_df

@op
def read_time_df(df: pandas.DataFrame):
assert set(df.columns) == {"foo", "date"}
assert (df == time_df).all()

snowflake_io_manager = build_snowflake_io_manager([SnowflakePandasTypeHandler()])

@job(
Expand All @@ -145,6 +199,7 @@ def read_pandas_df(df: pandas.DataFrame):
)
def io_manager_test_pipeline():
read_pandas_df(emit_pandas_df())
read_time_df(emit_time_df())

res = io_manager_test_pipeline.execute_in_process()
assert res.success
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def _get_table_slice(
table = asset_key_path[-1]
if len(asset_key_path) > 1:
schema = asset_key_path[-2]
elif context.resource_config.get("schema"):
schema = context.resource_config["schema"]
else:
schema = "public"
time_window = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def my_job():
"user": StringSource,
"password": StringSource,
"warehouse": Field(StringSource, is_required=False),
"schema": Field(StringSource, is_required=False),
}
)
def snowflake_io_manager():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=protected-access
from unittest.mock import MagicMock

from dagster_snowflake import DbTypeHandler
Expand Down Expand Up @@ -199,3 +200,40 @@ def test_non_asset_out():

assert len(handler.handle_input_calls) == 1
assert handler.handle_input_calls[0][1] == table_slice


def test_schema_defaults():
    """Schema resolution precedence: asset-key prefix > resource config > "public"."""
    handler = IntHandler()
    db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
    manager = DbIOManager(type_handlers=[handler], db_client=db_client)

    # multi-part asset key: second-to-last component is used as the schema
    asset_key = AssetKey(["schema1", "table1"])
    output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
    table_slice = manager._get_table_slice(output_context, output_context)

    assert table_slice.schema == "schema1"

    # single-part asset key with no schema configured: falls back to "public"
    asset_key = AssetKey(["table1"])
    output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
    table_slice = manager._get_table_slice(output_context, output_context)

    assert table_slice.schema == "public"

    # copy the shared module-level fixture before mutating it -- assigning it
    # directly would alias the dict and leak "schema" into every other test
    resource_config_w_schema = dict(resource_config)
    resource_config_w_schema["schema"] = "my_schema"

    # single-part asset key: schema comes from the resource config
    asset_key = AssetKey(["table1"])
    output_context = build_output_context(
        asset_key=asset_key, resource_config=resource_config_w_schema
    )
    table_slice = manager._get_table_slice(output_context, output_context)

    assert table_slice.schema == "my_schema"

    # an asset-key prefix still takes precedence over the configured schema
    asset_key = AssetKey(["schema1", "table1"])
    output_context = build_output_context(
        asset_key=asset_key, resource_config=resource_config_w_schema
    )
    table_slice = manager._get_table_slice(output_context, output_context)

    assert table_slice.schema == "schema1"