Skip to content

Commit

Permalink
Snowflake IO Manager handles pandas timestamps (#8760)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria authored Jul 8, 2022
1 parent 5938830 commit b721c70
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Mapping, Union, cast

import pandas as pd
from dagster_snowflake import DbTypeHandler
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake.snowflake_io_manager import SnowflakeDbClient, TableSlice
from pandas import DataFrame, read_sql
from snowflake.connector.pandas_tools import pd_writer

from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema
Expand All @@ -21,7 +21,35 @@ def _connect_snowflake(context: Union[InputContext, OutputContext], table_slice:
).get_connection(raw_conn=False)


class SnowflakePandasTypeHandler(DbTypeHandler[DataFrame]):
def _convert_timestamp_to_string(s: pd.Series) -> pd.Series:
"""
Converts columns of data of type pd.Timestamp to string so that it can be stored in
snowflake
"""
if pd.core.dtypes.common.is_datetime_or_timedelta_dtype(s):
return s.dt.strftime("%Y-%m-%d %H:%M:%S.%f %z")
else:
return s


def _convert_string_to_timestamp(s: pd.Series) -> pd.Series:
"""
Converts columns of strings in Timestamp format to pd.Timestamp to undo the conversion in
_convert_timestamp_to_string
This will not convert non-timestamp strings into timestamps (pd.to_datetime will raise an
exception if the string cannot be converted)
"""
if isinstance(s[0], str):
try:
return pd.to_datetime(s.values)
except ValueError:
return s
else:
return s


class SnowflakePandasTypeHandler(DbTypeHandler[pd.DataFrame]):
"""
Defines how to translate between slices of Snowflake tables and Pandas DataFrames.
Expand All @@ -40,13 +68,16 @@ def my_job():
"""

def handle_output(
self, context: OutputContext, table_slice: TableSlice, obj: DataFrame
self, context: OutputContext, table_slice: TableSlice, obj: pd.DataFrame
) -> Mapping[str, RawMetadataValue]:
from snowflake import connector # pylint: disable=no-name-in-module

connector.paramstyle = "pyformat"
with _connect_snowflake(context, table_slice) as con:
with_uppercase_cols = obj.rename(str.upper, copy=False, axis="columns")
with_uppercase_cols = with_uppercase_cols.apply(
_convert_timestamp_to_string, axis="index"
)
with_uppercase_cols.to_sql(
table_slice.table,
con=con.engine,
Expand All @@ -67,12 +98,13 @@ def handle_output(
),
}

def load_input(self, context: InputContext, table_slice: TableSlice) -> DataFrame:
def load_input(self, context: InputContext, table_slice: TableSlice) -> pd.DataFrame:
with _connect_snowflake(context, table_slice) as con:
result = read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result = pd.read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con)
result = result.apply(_convert_string_to_timestamp, axis="index")
result.columns = map(str.lower, result.columns)
return result

@property
def supported_types(self):
return [DataFrame]
return [pd.DataFrame]
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from dagster_snowflake.resources import SnowflakeConnection
from dagster_snowflake.snowflake_io_manager import TableSlice
from dagster_snowflake_pandas import SnowflakePandasTypeHandler
from dagster_snowflake_pandas.snowflake_pandas_type_handler import (
_convert_string_to_timestamp,
_convert_timestamp_to_string,
)
from pandas import DataFrame

from dagster import (
Expand Down Expand Up @@ -43,13 +47,13 @@


@contextmanager
def temporary_snowflake_table(schema_name: str, db_name: str) -> Iterator[str]:
def temporary_snowflake_table(schema_name: str, db_name: str, column_str: str) -> Iterator[str]:
snowflake_config = dict(database=db_name, **SHARED_BUILDKITE_SNOWFLAKE_CONF)
table_name = "test_io_manager_" + str(uuid.uuid4()).replace("-", "_")
with SnowflakeConnection(
snowflake_config, logging.getLogger("temporary_snowflake_table")
).get_connection() as conn:
conn.cursor().execute(f"create table {schema_name}.{table_name} (foo string, quux integer)")
conn.cursor().execute(f"create table {schema_name}.{table_name} ({column_str})")
try:
yield table_name
finally:
Expand Down Expand Up @@ -84,7 +88,7 @@ def test_handle_output():

def test_load_input():
with patch("dagster_snowflake_pandas.snowflake_pandas_type_handler._connect_snowflake"), patch(
"dagster_snowflake_pandas.snowflake_pandas_type_handler.read_sql"
"dagster_snowflake_pandas.snowflake_pandas_type_handler.pd.read_sql"
) as mock_read_sql:
mock_read_sql.return_value = DataFrame([{"COL1": "a", "COL2": 1}])

Expand All @@ -104,10 +108,37 @@ def test_load_input():
assert df.equals(DataFrame([{"col1": "a", "col2": 1}]))


def test_type_conversions():
# no timestamp data
no_time = pandas.Series([1, 2, 3, 4, 5])
converted = _convert_string_to_timestamp(_convert_timestamp_to_string(no_time))

assert (converted == no_time).all()

# timestamp data
with_time = pandas.Series(
[
pandas.Timestamp("2017-01-01T12:30:45.35"),
pandas.Timestamp("2017-02-01T12:30:45.35"),
pandas.Timestamp("2017-03-01T12:30:45.35"),
]
)
time_converted = _convert_string_to_timestamp(_convert_timestamp_to_string(with_time))

assert (with_time == time_converted).all()

# string that isn't a time
string_data = pandas.Series(["not", "a", "timestamp"])

assert (_convert_string_to_timestamp(string_data) == string_data).all()


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
def test_io_manager_with_snowflake_pandas():
with temporary_snowflake_table(
schema_name="SNOWFLAKE_IO_MANAGER_SCHEMA", db_name="TEST_SNOWFLAKE_IO_MANAGER"
schema_name="SNOWFLAKE_IO_MANAGER_SCHEMA",
db_name="TEST_SNOWFLAKE_IO_MANAGER",
column_str="foo string, quux integer",
) as table_name:

# Create a job with the temporary table name as an output, so that it will write to that table
Expand Down Expand Up @@ -148,3 +179,58 @@ def io_manager_test_pipeline():

res = io_manager_test_pipeline.execute_in_process()
assert res.success


@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB")
def test_io_manager_with_snowflake_pandas_timestamp_data():
with temporary_snowflake_table(
schema_name="SNOWFLAKE_IO_MANAGER_SCHEMA",
db_name="TEST_SNOWFLAKE_IO_MANAGER",
column_str="foo string, date TIMESTAMP_NTZ(9)",
) as table_name:

time_df = pandas.DataFrame(
{
"foo": ["bar", "baz"],
"date": [
pandas.Timestamp("2017-01-01T12:30:45.350"),
pandas.Timestamp("2017-02-01T12:30:45.350"),
],
}
)

@op(
out={
table_name: Out(
io_manager_key="snowflake", metadata={"schema": "SNOWFLAKE_IO_MANAGER_SCHEMA"}
)
}
)
def emit_time_df(_):
return time_df

@op
def read_time_df(df: pandas.DataFrame):
assert set(df.columns) == {"foo", "date"}
assert (df["date"] == time_df["date"]).all()

snowflake_io_manager = build_snowflake_io_manager([SnowflakePandasTypeHandler()])

@job(
resource_defs={"snowflake": snowflake_io_manager},
config={
"resources": {
"snowflake": {
"config": {
**SHARED_BUILDKITE_SNOWFLAKE_CONF,
"database": "TEST_SNOWFLAKE_IO_MANAGER",
}
}
}
},
)
def io_manager_timestamp_test_job():
read_time_df(emit_time_df())

res = io_manager_timestamp_test_job.execute_in_process()
assert res.success
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dagster import IOManager, InputContext, OutputContext
from dagster.core.definitions.metadata import RawMetadataValue
from dagster.core.definitions.time_window_partitions import TimeWindow
from dagster.core.errors import DagsterInvalidDefinitionError

SNOWFLAKE_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"

Expand Down Expand Up @@ -118,16 +119,45 @@ def _get_table_slice(
if context.has_asset_key:
asset_key_path = context.asset_key.path
table = asset_key_path[-1]
if len(asset_key_path) > 1:
if (
len(asset_key_path) > 1
and context.resource_config
and context.resource_config.get("schema")
):
raise DagsterInvalidDefinitionError(
f"Asset {asset_key_path} specifies a schema with "
f"its key prefixes {asset_key_path[:-1]}, but schema "
f"{context.resource_config.get('schema')} was also provided via run config. "
"Schema can only be specified one way."
)
elif len(asset_key_path) > 1:
schema = asset_key_path[-2]
elif context.resource_config and context.resource_config.get("schema"):
schema = context.resource_config["schema"]
else:
schema = "public"
time_window = (
context.asset_partitions_time_window if context.has_asset_partitions else None
)
else:
table = output_context.name
schema = output_context_metadata.get("schema", "public")
if (
output_context_metadata.get("schema")
and output_context.resource_config
and output_context.resource_config.get("schema")
):
raise DagsterInvalidDefinitionError(
f"Schema {output_context_metadata.get('schema')} "
"specified via output metadata, but conflicting schema "
f"{output_context.resource_config.get('schema')} was provided via run_config. "
"Schema can only be specified one way."
)
elif output_context.resource_config and output_context_metadata.get("schema"):
schema = output_context_metadata["schema"]
elif output_context.resource_config and output_context.resource_config.get("schema"):
schema = output_context.resource_config["schema"]
else:
schema = "public"
time_window = None

if time_window is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def my_job():
"user": StringSource,
"password": StringSource,
"warehouse": Field(StringSource, is_required=False),
"schema": Field(StringSource, is_required=False),
}
)
def snowflake_io_manager():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# pylint: disable=protected-access
from unittest.mock import MagicMock

import pytest
from dagster_snowflake import DbTypeHandler
from dagster_snowflake.db_io_manager import DbClient, DbIOManager, TablePartition, TableSlice
from pendulum import datetime

from dagster import AssetKey, InputContext, OutputContext, build_output_context
from dagster.core.definitions.time_window_partitions import TimeWindow
from dagster.core.errors import DagsterInvalidDefinitionError
from dagster.core.types.dagster_type import resolve_dagster_type

resource_config = {
Expand Down Expand Up @@ -199,3 +202,82 @@ def test_non_asset_out():

assert len(handler.handle_input_calls) == 1
assert handler.handle_input_calls[0][1] == table_slice


def test_asset_schema_defaults():
handler = IntHandler()
db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
manager = DbIOManager(type_handlers=[handler], db_client=db_client)

asset_key = AssetKey(["schema1", "table1"])
output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
table_slice = manager._get_table_slice(output_context, output_context)

assert table_slice.schema == "schema1"

asset_key = AssetKey(["table1"])
output_context = build_output_context(asset_key=asset_key, resource_config=resource_config)
table_slice = manager._get_table_slice(output_context, output_context)

assert table_slice.schema == "public"

resource_config_w_schema = {
"database": "database_abc",
"account": "account_abc",
"user": "user_abc",
"password": "password_abc",
"warehouse": "warehouse_abc",
"schema": "my_schema",
}

asset_key = AssetKey(["table1"])
output_context = build_output_context(
asset_key=asset_key, resource_config=resource_config_w_schema
)
table_slice = manager._get_table_slice(output_context, output_context)

assert table_slice.schema == "my_schema"

asset_key = AssetKey(["schema1", "table1"])
output_context = build_output_context(
asset_key=asset_key, resource_config=resource_config_w_schema
)
with pytest.raises(DagsterInvalidDefinitionError):
table_slice = manager._get_table_slice(output_context, output_context)


def test_output_schema_defaults():
handler = IntHandler()
db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value=""))
manager = DbIOManager(type_handlers=[handler], db_client=db_client)
output_context = build_output_context(
name="table1", metadata={"schema": "schema1"}, resource_config=resource_config
)
table_slice = manager._get_table_slice(output_context, output_context)

assert table_slice.schema == "schema1"

output_context = build_output_context(name="table1", resource_config=resource_config)
table_slice = manager._get_table_slice(output_context, output_context)

assert table_slice.schema == "public"

resource_config_w_schema = {
"database": "database_abc",
"account": "account_abc",
"user": "user_abc",
"password": "password_abc",
"warehouse": "warehouse_abc",
"schema": "my_schema",
}

output_context = build_output_context(name="table1", resource_config=resource_config_w_schema)
table_slice = manager._get_table_slice(output_context, output_context)

assert table_slice.schema == "my_schema"

output_context = build_output_context(
name="table1", metadata={"schema": "schema1"}, resource_config=resource_config_w_schema
)
with pytest.raises(DagsterInvalidDefinitionError):
table_slice = manager._get_table_slice(output_context, output_context)

0 comments on commit b721c70

Please sign in to comment.