diff --git a/CHANGELOG.md b/CHANGELOG.md index 18d2d7a1257..ae2cacdc3d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,7 @@ - Fixed a bug in local testing mode that caused a column to contain None when it should contain 0 - Fixed a bug in `StructField.from_json` that prevented TimestampTypes with `tzinfo` from being parsed correctly. - Fixed a bug in function `date_format` that caused an error when the input column was date type or timestamp type. +- Fixed a bug in dataframe where a null value could be inserted into a non-nullable column. - Fixed a bug in `replace` when passing `Column` expression objects. ### Snowpark pandas API Updates diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index 99b63cf61f6..7aadd7ac2fc 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -6,6 +6,8 @@ from collections import Counter, defaultdict from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union +from snowflake.connector import IntegrityError + import snowflake.snowpark from snowflake.snowpark._internal.analyzer.analyzer_utils import ( alias_expression, @@ -975,6 +977,8 @@ def do_resolve_with_resolved_children( if logical_plan.data: if not logical_plan.is_large_local_data: + if logical_plan.is_contain_illegal_null_value: + raise IntegrityError("NULL result in a non-nullable column") return self.plan_builder.query( values_statement(logical_plan.output, logical_plan.data), logical_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 7d8fcedabab..e9f1c0575ca 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -158,6 +158,20 @@ def is_large_local_data(self) -> bool: return len(self.data) * len(self.output) >= 
ARRAY_BIND_THRESHOLD + @property + def is_contain_illegal_null_value(self) -> bool: + from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD + + rows_to_compare = min( + ARRAY_BIND_THRESHOLD // len(self.output) + 1, len(self.data) + ) + for j in range(len(self.output)): + if not self.output[j].nullable: + for i in range(rows_to_compare): + if self.data[i][j] is None: + return True + return False + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: if self.is_large_local_data: diff --git a/tests/integ/compiler/test_query_generator.py b/tests/integ/compiler/test_query_generator.py index bbcb2d0f36d..8902f66823d 100644 --- a/tests/integ/compiler/test_query_generator.py +++ b/tests/integ/compiler/test_query_generator.py @@ -6,6 +6,7 @@ from typing import List import pytest +from snowflake.connector import IntegrityError from snowflake.snowpark import Window from snowflake.snowpark._internal.analyzer import analyzer @@ -51,6 +52,7 @@ random_name_for_temp_object, ) from snowflake.snowpark.functions import avg, col, lit, when_matched +from snowflake.snowpark.types import StructType, StructField, LongType from tests.integ.scala.test_dataframe_reader_suite import get_reader from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker from tests.utils import TestFiles, Utils @@ -533,3 +535,26 @@ def test_select_alias(session): # Add a new column d that doesn't use c after c was added previously. Flatten safely. 
df2 = df1.select("a", "b", "c", (col("a") + col("b") + 1).as_("d")) check_generated_plan_queries(df2._plan) + + +def test_nullable_is_false_dataframe(session): + from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD + + schema = StructType([StructField("key", LongType(), nullable=True)]) + assert session.create_dataframe([None], schema=schema).collect()[0][0] is None + + assert ( + session.create_dataframe( + [None for _ in range(ARRAY_BIND_THRESHOLD + 1)], schema=schema + ).collect()[0][0] + is None + ) + + schema = StructType([StructField("key", LongType(), nullable=False)]) + with pytest.raises(IntegrityError, match="NULL result in a non-nullable column"): + session.create_dataframe([None for _ in range(10)], schema=schema).collect() + + with pytest.raises(IntegrityError, match="NULL result in a non-nullable column"): + session.create_dataframe( + [None for _ in range(ARRAY_BIND_THRESHOLD + 1)], schema=schema + ).collect() diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index c3abf6f58dc..1e689f4cc2c 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -1827,9 +1827,12 @@ def test_create_dataframe_with_schema_col_names(session): for field, expected_name in zip(df.schema.fields, col_names[:2] + ["_3", "_4"]): assert Utils.equals_ignore_case(field.name, expected_name) + # specify nullable in StructType to avoid inserting null values into a non-nullable column + struct_col_name = StructType([StructField(col, StringType()) for col in col_names]) + + # the column names provided via schema keyword will overwrite other column names df = session.create_dataframe( - [{"aa": 1, "bb": 2, "cc": 3, "dd": 4}], schema=col_names + [{"aa": 1, "bb": 2, "cc": 3, "dd": 4}], schema=struct_col_name ) for field, expected_name in zip(df.schema.fields, col_names): assert Utils.equals_ignore_case(field.name, expected_name) @@ -2734,15 +2737,15 @@ def test_save_as_table_nullable_test( StructField("B", data_type, 
True), ] ) - df = session.create_dataframe( - [(None, None)] * (5000 if large_data else 1), schema=schema - ) try: with pytest.raises( (IntegrityError, SnowparkSQLException), match="NULL result in a non-nullable column", ): + df = session.create_dataframe( + [(None, None)] * (5000 if large_data else 1), schema=schema + ) df.write.save_as_table(table_name, mode=save_mode) finally: Utils.drop_table(session, table_name) @@ -2768,13 +2771,13 @@ def mock_run_query(*args, **kwargs): StructField("B", IntegerType(), True), ] ) - df = session.create_dataframe([(None, None)], schema=schema) try: with pytest.raises( (IntegrityError, SnowparkSQLException), match="NULL result in a non-nullable column", ): + df = session.create_dataframe([(None, None)], schema=schema) df.write.save_as_table(table_name, mode=save_mode) finally: Utils.drop_table(session, table_name)