Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1877449: Exception should be thrown when creating a DataFrame with a null value and nullable set to False #2849

Merged
merged 9 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
- Fixed a bug in local testing mode that caused a column to contain None when it should contain 0
- Fixed a bug in `StructField.from_json` that prevented TimestampTypes with `tzinfo` from being parsed correctly.
- Fixed a bug in function `date_format` that caused an error when the input column was date type or timestamp type.
- Fixed a bug in dataframe creation where a null value could be inserted into a non-nullable column.
- Fixed a bug in `replace` when passing `Column` expression objects.

### Snowpark pandas API Updates
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union

from snowflake.connector import IntegrityError

import snowflake.snowpark
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
alias_expression,
Expand Down Expand Up @@ -975,6 +977,8 @@ def do_resolve_with_resolved_children(

if logical_plan.data:
if not logical_plan.is_large_local_data:
if logical_plan.is_contain_illegal_null_value:
raise IntegrityError("NULL result in a non-nullable column")
return self.plan_builder.query(
values_statement(logical_plan.output, logical_plan.data),
logical_plan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,20 @@ def is_large_local_data(self) -> bool:

return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD

@property
def is_contain_illegal_null_value(self) -> bool:
    """Return True if a ``None`` value appears in a column declared non-nullable.

    Only the rows that would be inlined into a VALUES statement are inspected
    (the first ``ARRAY_BIND_THRESHOLD // len(self.output) + 1`` rows); larger
    local data goes through a different upload path where the server enforces
    the NOT NULL constraint itself.
    """
    # Local import to avoid a circular dependency with the analyzer module.
    from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD

    # No columns (or no rows) means nothing can violate a NOT NULL constraint,
    # and guards the division below against ZeroDivisionError.
    if not self.output or not self.data:
        return False

    rows_to_compare = min(
        ARRAY_BIND_THRESHOLD // len(self.output) + 1, len(self.data)
    )
    # Hoist the nullability check: only non-nullable columns need scanning.
    non_nullable_cols = [
        j for j, attr in enumerate(self.output) if not attr.nullable
    ]
    return any(
        self.data[i][j] is None
        for j in non_nullable_cols
        for i in range(rows_to_compare)
    )

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
if self.is_large_local_data:
Expand Down
25 changes: 25 additions & 0 deletions tests/integ/compiler/test_query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import List

import pytest
from snowflake.connector import IntegrityError

from snowflake.snowpark import Window
from snowflake.snowpark._internal.analyzer import analyzer
Expand Down Expand Up @@ -51,6 +52,7 @@
random_name_for_temp_object,
)
from snowflake.snowpark.functions import avg, col, lit, when_matched
from snowflake.snowpark.types import StructType, StructField, LongType
from tests.integ.scala.test_dataframe_reader_suite import get_reader
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
from tests.utils import TestFiles, Utils
Expand Down Expand Up @@ -533,3 +535,26 @@ def test_select_alias(session):
# Add a new column d that doesn't use c after c was added previously. Flatten safely.
df2 = df1.select("a", "b", "c", (col("a") + col("b") + 1).as_("d"))
check_generated_plan_queries(df2._plan)


def test_nullable_is_false_dataframe(session):
    """None values are accepted in nullable columns but rejected, via
    IntegrityError, in non-nullable columns — on both the small (inlined
    VALUES) and large (above ARRAY_BIND_THRESHOLD) local-data code paths."""
    from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD

    many_nulls = [None] * (ARRAY_BIND_THRESHOLD + 1)

    # Nullable column: None round-trips on both code paths.
    nullable_schema = StructType([StructField("key", LongType(), nullable=True)])
    small_df = session.create_dataframe([None], schema=nullable_schema)
    assert small_df.collect()[0][0] is None

    large_df = session.create_dataframe(many_nulls, schema=nullable_schema)
    assert large_df.collect()[0][0] is None

    # Non-nullable column: inserting None must raise on both code paths.
    non_nullable_schema = StructType([StructField("key", LongType(), nullable=False)])
    with pytest.raises(IntegrityError, match="NULL result in a non-nullable column"):
        session.create_dataframe([None] * 10, schema=non_nullable_schema).collect()

    with pytest.raises(IntegrityError, match="NULL result in a non-nullable column"):
        session.create_dataframe(many_nulls, schema=non_nullable_schema).collect()
13 changes: 8 additions & 5 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,9 +1827,12 @@ def test_create_dataframe_with_schema_col_names(session):
for field, expected_name in zip(df.schema.fields, col_names[:2] + ["_3", "_4"]):
assert Utils.equals_ignore_case(field.name, expected_name)

# specify nullable in StructType to avoid inserting null values into non-nullable columns
struct_col_name = StructType([StructField(col, StringType()) for col in col_names])

# the column names provided via schema keyword will overwrite other column names
df = session.create_dataframe(
[{"aa": 1, "bb": 2, "cc": 3, "dd": 4}], schema=col_names
[{"aa": 1, "bb": 2, "cc": 3, "dd": 4}], schema=struct_col_name
)
for field, expected_name in zip(df.schema.fields, col_names):
assert Utils.equals_ignore_case(field.name, expected_name)
Expand Down Expand Up @@ -2734,15 +2737,15 @@ def test_save_as_table_nullable_test(
StructField("B", data_type, True),
]
)
df = session.create_dataframe(
[(None, None)] * (5000 if large_data else 1), schema=schema
)

try:
with pytest.raises(
(IntegrityError, SnowparkSQLException),
match="NULL result in a non-nullable column",
):
df = session.create_dataframe(
[(None, None)] * (5000 if large_data else 1), schema=schema
)
df.write.save_as_table(table_name, mode=save_mode)
finally:
Utils.drop_table(session, table_name)
Expand All @@ -2768,13 +2771,13 @@ def mock_run_query(*args, **kwargs):
StructField("B", IntegerType(), True),
]
)
df = session.create_dataframe([(None, None)], schema=schema)

try:
with pytest.raises(
(IntegrityError, SnowparkSQLException),
match="NULL result in a non-nullable column",
):
df = session.create_dataframe([(None, None)], schema=schema)
df.write.save_as_table(table_name, mode=save_mode)
finally:
Utils.drop_table(session, table_name)
Expand Down
Loading