diff --git a/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas/snowflake_pandas_type_handler.py b/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas/snowflake_pandas_type_handler.py index fa5566c50a067..da2d44dc17875 100644 --- a/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas/snowflake_pandas_type_handler.py +++ b/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas/snowflake_pandas_type_handler.py @@ -1,9 +1,9 @@ from typing import Mapping, Union, cast +import pandas as pd from dagster_snowflake import DbTypeHandler from dagster_snowflake.resources import SnowflakeConnection from dagster_snowflake.snowflake_io_manager import SnowflakeDbClient, TableSlice -from pandas import DataFrame, read_sql from snowflake.connector.pandas_tools import pd_writer from dagster import InputContext, MetadataValue, OutputContext, TableColumn, TableSchema @@ -21,7 +21,35 @@ def _connect_snowflake(context: Union[InputContext, OutputContext], table_slice: ).get_connection(raw_conn=False) -class SnowflakePandasTypeHandler(DbTypeHandler[DataFrame]): +def _convert_timestamp_to_string(s: pd.Series) -> pd.Series: + """ + Converts columns of data of type pd.Timestamp to string so that it can be stored in + snowflake + """ + if pd.core.dtypes.common.is_datetime_or_timedelta_dtype(s): + return s.dt.strftime("%Y-%m-%d %H:%M:%S.%f %z") + else: + return s + + +def _convert_string_to_timestamp(s: pd.Series) -> pd.Series: + """ + Converts columns of strings in Timestamp format to pd.Timestamp to undo the conversion in + _convert_timestamp_to_string + + This will not convert non-timestamp strings into timestamps (pd.to_datetime will raise an + exception if the string cannot be converted) + """ + if isinstance(s[0], str): + try: + return pd.to_datetime(s.values) + except ValueError: + return s + else: + return s + + +class SnowflakePandasTypeHandler(DbTypeHandler[pd.DataFrame]): """ Defines how to translate between slices of Snowflake tables and Pandas DataFrames. @@ -40,13 +68,16 @@ def my_job(): """ def handle_output( - self, context: OutputContext, table_slice: TableSlice, obj: DataFrame + self, context: OutputContext, table_slice: TableSlice, obj: pd.DataFrame ) -> Mapping[str, RawMetadataValue]: from snowflake import connector # pylint: disable=no-name-in-module connector.paramstyle = "pyformat" with _connect_snowflake(context, table_slice) as con: with_uppercase_cols = obj.rename(str.upper, copy=False, axis="columns") + with_uppercase_cols = with_uppercase_cols.apply( + _convert_timestamp_to_string, axis="index" + ) with_uppercase_cols.to_sql( table_slice.table, con=con.engine, @@ -67,12 +98,13 @@ def handle_output( ), } - def load_input(self, context: InputContext, table_slice: TableSlice) -> DataFrame: + def load_input(self, context: InputContext, table_slice: TableSlice) -> pd.DataFrame: with _connect_snowflake(context, table_slice) as con: - result = read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con) + result = pd.read_sql(sql=SnowflakeDbClient.get_select_statement(table_slice), con=con) + result = result.apply(_convert_string_to_timestamp, axis="index") result.columns = map(str.lower, result.columns) return result @property def supported_types(self): - return [DataFrame] + return [pd.DataFrame] diff --git a/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py b/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py index 556d20bb3a237..045f02682b567 100644 --- a/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py +++ b/python_modules/libraries/dagster-snowflake-pandas/dagster_snowflake_pandas_tests/test_snowflake_pandas_type_handler.py @@ -11,6 +11,10 @@ from dagster_snowflake.resources import SnowflakeConnection from dagster_snowflake.snowflake_io_manager import TableSlice from dagster_snowflake_pandas import SnowflakePandasTypeHandler +from dagster_snowflake_pandas.snowflake_pandas_type_handler import ( + _convert_string_to_timestamp, + _convert_timestamp_to_string, +) from pandas import DataFrame from dagster import ( @@ -43,13 +47,13 @@ @contextmanager -def temporary_snowflake_table(schema_name: str, db_name: str) -> Iterator[str]: +def temporary_snowflake_table(schema_name: str, db_name: str, column_str: str) -> Iterator[str]: snowflake_config = dict(database=db_name, **SHARED_BUILDKITE_SNOWFLAKE_CONF) table_name = "test_io_manager_" + str(uuid.uuid4()).replace("-", "_") with SnowflakeConnection( snowflake_config, logging.getLogger("temporary_snowflake_table") ).get_connection() as conn: - conn.cursor().execute(f"create table {schema_name}.{table_name} (foo string, quux integer)") + conn.cursor().execute(f"create table {schema_name}.{table_name} ({column_str})") try: yield table_name finally: @@ -84,7 +88,7 @@ def test_handle_output(): def test_load_input(): with patch("dagster_snowflake_pandas.snowflake_pandas_type_handler._connect_snowflake"), patch( - "dagster_snowflake_pandas.snowflake_pandas_type_handler.read_sql" + "dagster_snowflake_pandas.snowflake_pandas_type_handler.pd.read_sql" ) as mock_read_sql: mock_read_sql.return_value = DataFrame([{"COL1": "a", "COL2": 1}]) @@ -104,10 +108,37 @@ def test_load_input(): assert df.equals(DataFrame([{"col1": "a", "col2": 1}])) +def test_type_conversions(): + # no timestamp data + no_time = pandas.Series([1, 2, 3, 4, 5]) + converted = _convert_string_to_timestamp(_convert_timestamp_to_string(no_time)) + + assert (converted == no_time).all() + + # timestamp data + with_time = pandas.Series( + [ + pandas.Timestamp("2017-01-01T12:30:45.35"), + pandas.Timestamp("2017-02-01T12:30:45.35"), + pandas.Timestamp("2017-03-01T12:30:45.35"), + ] + ) + time_converted = _convert_string_to_timestamp(_convert_timestamp_to_string(with_time)) + + assert (with_time == time_converted).all() + + # string that isn't a time + string_data = pandas.Series(["not", "a", "timestamp"]) + + assert (_convert_string_to_timestamp(string_data) == string_data).all() + + @pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB") def test_io_manager_with_snowflake_pandas(): with temporary_snowflake_table( - schema_name="SNOWFLAKE_IO_MANAGER_SCHEMA", db_name="TEST_SNOWFLAKE_IO_MANAGER" + schema_name="SNOWFLAKE_IO_MANAGER_SCHEMA", + db_name="TEST_SNOWFLAKE_IO_MANAGER", + column_str="foo string, quux integer", ) as table_name: # Create a job with the temporary table name as an output, so that it will write to that table @@ -148,3 +179,58 @@ def io_manager_test_pipeline(): res = io_manager_test_pipeline.execute_in_process() assert res.success + + +@pytest.mark.skipif(not IS_BUILDKITE, reason="Requires access to the BUILDKITE snowflake DB") +def test_io_manager_with_snowflake_pandas_timestamp_data(): + with temporary_snowflake_table( + schema_name="SNOWFLAKE_IO_MANAGER_SCHEMA", + db_name="TEST_SNOWFLAKE_IO_MANAGER", + column_str="foo string, date TIMESTAMP_NTZ(9)", + ) as table_name: + + time_df = pandas.DataFrame( + { + "foo": ["bar", "baz"], + "date": [ + pandas.Timestamp("2017-01-01T12:30:45.350"), + pandas.Timestamp("2017-02-01T12:30:45.350"), + ], + } + ) + + @op( + out={ + table_name: Out( + io_manager_key="snowflake", metadata={"schema": "SNOWFLAKE_IO_MANAGER_SCHEMA"} + ) + } + ) + def emit_time_df(_): + return time_df + + @op + def read_time_df(df: pandas.DataFrame): + assert set(df.columns) == {"foo", "date"} + assert (df["date"] == time_df["date"]).all() + + snowflake_io_manager = build_snowflake_io_manager([SnowflakePandasTypeHandler()]) + + @job( + resource_defs={"snowflake": snowflake_io_manager}, + config={ + "resources": { + "snowflake": { + "config": { + **SHARED_BUILDKITE_SNOWFLAKE_CONF, + "database": "TEST_SNOWFLAKE_IO_MANAGER", + } + } + } + }, + ) + def io_manager_timestamp_test_job(): + read_time_df(emit_time_df()) + + res = io_manager_timestamp_test_job.execute_in_process() + assert res.success diff --git a/python_modules/libraries/dagster-snowflake/dagster_snowflake/db_io_manager.py b/python_modules/libraries/dagster-snowflake/dagster_snowflake/db_io_manager.py index fb5795ceceaec..ba0121affe196 100644 --- a/python_modules/libraries/dagster-snowflake/dagster_snowflake/db_io_manager.py +++ b/python_modules/libraries/dagster-snowflake/dagster_snowflake/db_io_manager.py @@ -16,6 +16,7 @@ from dagster import IOManager, InputContext, OutputContext from dagster.core.definitions.metadata import RawMetadataValue from dagster.core.definitions.time_window_partitions import TimeWindow +from dagster.core.errors import DagsterInvalidDefinitionError SNOWFLAKE_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S" @@ -118,8 +119,21 @@ def _get_table_slice( if context.has_asset_key: asset_key_path = context.asset_key.path table = asset_key_path[-1] - if len(asset_key_path) > 1: + if ( + len(asset_key_path) > 1 + and context.resource_config + and context.resource_config.get("schema") + ): + raise DagsterInvalidDefinitionError( + f"Asset {asset_key_path} specifies a schema with " + f"its key prefixes {asset_key_path[:-1]}, but schema " + f"{context.resource_config.get('schema')} was also provided via run config. " + "Schema can only be specified one way." + ) + elif len(asset_key_path) > 1: schema = asset_key_path[-2] + elif context.resource_config and context.resource_config.get("schema"): + schema = context.resource_config["schema"] else: schema = "public" time_window = ( @@ -127,7 +141,23 @@ def _get_table_slice( ) else: table = output_context.name - schema = output_context_metadata.get("schema", "public") + if ( + output_context_metadata.get("schema") + and output_context.resource_config + and output_context.resource_config.get("schema") + ): + raise DagsterInvalidDefinitionError( + f"Schema {output_context_metadata.get('schema')} " + "specified via output metadata, but conflicting schema " + f"{output_context.resource_config.get('schema')} was provided via run_config. " + "Schema can only be specified one way." + ) + elif output_context.resource_config and output_context_metadata.get("schema"): + schema = output_context_metadata["schema"] + elif output_context.resource_config and output_context.resource_config.get("schema"): + schema = output_context.resource_config["schema"] + else: + schema = "public" time_window = None if time_window is not None: diff --git a/python_modules/libraries/dagster-snowflake/dagster_snowflake/snowflake_io_manager.py b/python_modules/libraries/dagster-snowflake/dagster_snowflake/snowflake_io_manager.py index 522925d67190c..bdd5ea7ca5a24 100644 --- a/python_modules/libraries/dagster-snowflake/dagster_snowflake/snowflake_io_manager.py +++ b/python_modules/libraries/dagster-snowflake/dagster_snowflake/snowflake_io_manager.py @@ -40,6 +40,7 @@ def my_job(): "user": StringSource, "password": StringSource, "warehouse": Field(StringSource, is_required=False), + "schema": Field(StringSource, is_required=False), } ) def snowflake_io_manager(): diff --git a/python_modules/libraries/dagster-snowflake/dagster_snowflake_tests/test_db_io_manager.py b/python_modules/libraries/dagster-snowflake/dagster_snowflake_tests/test_db_io_manager.py index 779109bf15cf1..a7433eaeede43 100644 --- a/python_modules/libraries/dagster-snowflake/dagster_snowflake_tests/test_db_io_manager.py +++ b/python_modules/libraries/dagster-snowflake/dagster_snowflake_tests/test_db_io_manager.py @@ -1,11 +1,14 @@ +# pylint: disable=protected-access from unittest.mock import MagicMock +import pytest from dagster_snowflake import DbTypeHandler from dagster_snowflake.db_io_manager import DbClient, DbIOManager, TablePartition, TableSlice from pendulum import datetime from dagster import AssetKey, InputContext, OutputContext, build_output_context from dagster.core.definitions.time_window_partitions import TimeWindow +from dagster.core.errors import DagsterInvalidDefinitionError from dagster.core.types.dagster_type import resolve_dagster_type resource_config = { @@ -199,3 +202,82 @@ def test_non_asset_out(): assert len(handler.handle_input_calls) == 1 assert handler.handle_input_calls[0][1] == table_slice + + +def test_asset_schema_defaults(): + handler = IntHandler() + db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value="")) + manager = DbIOManager(type_handlers=[handler], db_client=db_client) + + asset_key = AssetKey(["schema1", "table1"]) + output_context = build_output_context(asset_key=asset_key, resource_config=resource_config) + table_slice = manager._get_table_slice(output_context, output_context) + + assert table_slice.schema == "schema1" + + asset_key = AssetKey(["table1"]) + output_context = build_output_context(asset_key=asset_key, resource_config=resource_config) + table_slice = manager._get_table_slice(output_context, output_context) + + assert table_slice.schema == "public" + + resource_config_w_schema = { + "database": "database_abc", + "account": "account_abc", + "user": "user_abc", + "password": "password_abc", + "warehouse": "warehouse_abc", + "schema": "my_schema", + } + + asset_key = AssetKey(["table1"]) + output_context = build_output_context( + asset_key=asset_key, resource_config=resource_config_w_schema + ) + table_slice = manager._get_table_slice(output_context, output_context) + + assert table_slice.schema == "my_schema" + + asset_key = AssetKey(["schema1", "table1"]) + output_context = build_output_context( + asset_key=asset_key, resource_config=resource_config_w_schema + ) + with pytest.raises(DagsterInvalidDefinitionError): + table_slice = manager._get_table_slice(output_context, output_context) + + +def test_output_schema_defaults(): + handler = IntHandler() + db_client = MagicMock(spec=DbClient, get_select_statement=MagicMock(return_value="")) + manager = DbIOManager(type_handlers=[handler], db_client=db_client) + output_context = build_output_context( + name="table1", metadata={"schema": "schema1"}, resource_config=resource_config + ) + table_slice = manager._get_table_slice(output_context, output_context) + + assert table_slice.schema == "schema1" + + output_context = build_output_context(name="table1", resource_config=resource_config) + table_slice = manager._get_table_slice(output_context, output_context) + + assert table_slice.schema == "public" + + resource_config_w_schema = { + "database": "database_abc", + "account": "account_abc", + "user": "user_abc", + "password": "password_abc", + "warehouse": "warehouse_abc", + "schema": "my_schema", + } + + output_context = build_output_context(name="table1", resource_config=resource_config_w_schema) + table_slice = manager._get_table_slice(output_context, output_context) + + assert table_slice.schema == "my_schema" + + output_context = build_output_context( + name="table1", metadata={"schema": "schema1"}, resource_config=resource_config_w_schema + ) + with pytest.raises(DagsterInvalidDefinitionError): + table_slice = manager._get_table_slice(output_context, output_context)