Skip to content

Commit

Permalink
Added timezone type to dfs when the corresponding pandas dataframe al…
Browse files Browse the repository at this point in the history
…so has timezone
  • Loading branch information
frederiksteinerSBB committed May 24, 2024
1 parent 5e61c94 commit 92808e5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/snowflake/connector/pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,15 @@ def write_pandas(
# if the column name contains a double quote, we need to escape it by replacing with two double quotes
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers
snowflake_column_names = [str(c).replace('"', '""') for c in df.columns]
tz_columns = [
str(c).replace('"', '""') for c in df.columns if pandas.api.types.is_datetime64tz_dtype(df[c])
]
else:
quote = ""
snowflake_column_names = list(df.columns)
tz_columns = [
c for c in df.columns if pandas.api.types.is_datetime64tz_dtype(df[c])
]
columns = quote + f"{quote},{quote}".join(snowflake_column_names) + quote

def drop_object(name: str, object_type: str) -> None:
Expand All @@ -376,6 +382,7 @@ def drop_object(name: str, object_type: str) -> None:
column_type_mapping = dict(
cursor.execute(infer_schema_sql, _is_internal=True).fetchall()
)
column_type_mapping.update({c: "TIMESTAMP_TZ" for c in tz_columns})
# Infer schema can return the columns out of order depending on the chunking we do when uploading
# so we have to iterate through the dataframe columns to make sure we create the table with its
# columns in order
Expand Down
1 change: 1 addition & 0 deletions test/integ/pandas/test_pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def test_write_pandas_use_logical_type(
write_pandas(**write_pandas_kwargs)
df_read = cnx.cursor().execute(select_sql).fetch_pandas_all()
assert all(df_write == df_read)
assert pandas.api.types.is_datetime64tz_dtype(df_read[col_name])
# For other use_logical_type values, a UserWarning should be displayed.
else:
with pytest.warns(UserWarning, match="Dataframe contains a datetime.*"):
Expand Down

0 comments on commit 92808e5

Please sign in to comment.