From e3103999f3f9598bbeaf1623fa8ac1bbe5a1015c Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Thu, 9 Jan 2025 16:48:05 -0800
Subject: [PATCH 1/7] Fixed a bug in dataframe creation that allowed null values to be inserted into a non-nullable column

---
 CHANGELOG.md                                   |  1 +
 .../snowpark/_internal/analyzer/analyzer.py    |  5 +++-
 .../_internal/analyzer/snowflake_plan_node.py  |  7 ++++++
 tests/integ/compiler/test_query_generator.py   | 25 +++++++++++++++++++
 4 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7ebf01e80d1..5d55aaa19b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@

 - Fixed a bug in local testing mode that caused a column to contain None when it should contain 0
 - Fixed a bug in StructField.from_json that prevented TimestampTypes with tzinfo from being parsed correctly.
+- Fixed a bug in dataframe creation that allowed null values to be inserted into a non-nullable column.

 ### Snowpark pandas API Updates
diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
index 328a7d5d0b3..6d25d160903 100644
--- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py
+++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -974,7 +974,10 @@ def do_resolve_with_resolved_children(
             schema_query = schema_query_for_values_statement(logical_plan.output)

             if logical_plan.data:
-                if not logical_plan.is_large_local_data:
+                if (
+                    not logical_plan.is_large_local_data
+                    and logical_plan.is_all_column_nullable
+                ):
                     return self.plan_builder.query(
                         values_statement(logical_plan.output, logical_plan.data),
                         logical_plan,
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
index 017ec433163..b8301749eea 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -158,6 +158,13 @@ def is_large_local_data(self) -> bool:
         return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD

+    @property
+    def is_all_column_nullable(self) -> bool:
+        for attribute in self.output:
+            if not attribute.nullable:
+                return False
+        return True
+
     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
         if self.is_large_local_data:
diff --git a/tests/integ/compiler/test_query_generator.py b/tests/integ/compiler/test_query_generator.py
index d62fec1e081..14b34910c98 100644
--- a/tests/integ/compiler/test_query_generator.py
+++ b/tests/integ/compiler/test_query_generator.py
@@ -6,6 +6,7 @@
 from typing import List

 import pytest
+from snowflake.connector import IntegrityError

 from snowflake.snowpark import Window
 from snowflake.snowpark._internal.analyzer import analyzer
@@ -51,6 +52,7 @@
     random_name_for_temp_object,
 )
 from snowflake.snowpark.functions import avg, col, lit, when_matched
+from snowflake.snowpark.types import StructType, StructField, LongType
 from tests.integ.scala.test_dataframe_reader_suite import get_reader
 from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
 from tests.utils import TestFiles, Utils
@@ -533,3 +535,26 @@ def test_select_alias(session):
     # Add a new column d that doesn't use c after c was added previously. Flatten safely.
     df2 = df1.select("a", "b", "c", (col("a") + col("b") + 1).as_("d"))
     check_generated_plan_queries(df2._plan)
+
+
+def test_nullable_is_false_dataframe(session):
+    from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD
+
+    schema = StructType([StructField("key", LongType(), nullable=True)])
+    assert session.create_dataframe([None], schema=schema).collect()[0][0] is None
+
+    assert (
+        session.create_dataframe(
+            [None for _ in range(ARRAY_BIND_THRESHOLD + 1)], schema=schema
+        ).collect()[0][0]
+        is None
+    )
+
+    schema = StructType([StructField("key", LongType(), nullable=False)])
+    with pytest.raises(IntegrityError, match="NULL result in a non-nullable column"):
+        session.create_dataframe([None for _ in range(10)], schema=schema).collect()
+
+    with pytest.raises(IntegrityError, match="NULL result in a non-nullable column"):
+        session.create_dataframe(
+            [None for _ in range(ARRAY_BIND_THRESHOLD + 1)], schema=schema
+        ).collect()

From 8c189baa037b08f93981e76f1873a315b25b2ce7 Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Fri, 10 Jan 2025 10:22:01 -0800
Subject: [PATCH 2/7] remove actual sql execution

---
 .../snowpark/_internal/analyzer/analyzer.py    |  9 +++++----
 .../_internal/analyzer/snowflake_plan_node.py  | 16 +++++++++++-----
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
index 6d25d160903..b02c4136ae5 100644
--- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py
+++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -6,6 +6,8 @@
 from collections import Counter, defaultdict
 from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union

+from snowflake.connector import IntegrityError
+
 import snowflake.snowpark
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     alias_expression,
@@ -974,10 +976,9 @@ def do_resolve_with_resolved_children(
             schema_query = schema_query_for_values_statement(logical_plan.output)

             if logical_plan.data:
-                if (
-                    not logical_plan.is_large_local_data
-                    and logical_plan.is_all_column_nullable
-                ):
+                if not logical_plan.is_large_local_data:
+                    if logical_plan.is_contain_illegal_null_value:
+                        raise IntegrityError("NULL result in a non-nullable column")
                     return self.plan_builder.query(
                         values_statement(logical_plan.output, logical_plan.data),
                         logical_plan,
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
index b8301749eea..752f51ab642 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -159,11 +159,17 @@ def is_large_local_data(self) -> bool:
         return len(self.data) * len(self.output) >= ARRAY_BIND_THRESHOLD

     @property
-    def is_all_column_nullable(self) -> bool:
-        for attribute in self.output:
-            if not attribute.nullable:
-                return False
-        return True
+    def is_contain_illegal_null_value(self) -> bool:
+        from snowflake.snowpark._internal.analyzer.analyzer import ARRAY_BIND_THRESHOLD
+
+        rows_to_compare = min(
+            ARRAY_BIND_THRESHOLD // len(self.output) + 1, len(self.data)
+        )
+        for i in range(rows_to_compare):
+            for j in range(len(self.output)):
+                if self.data[i][j] is None and not self.output[j].nullable:
+                    return True
+        return False

     @property
     def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:

From 206cfd50089246dbcfe266ee981d5a57ed2f220e Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Fri, 10 Jan 2025 11:12:58 -0800
Subject: [PATCH 3/7] test

---
 src/snowflake/snowpark/_internal/type_utils.py | 2 +-
 tests/integ/test_dataframe.py                  | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py
index a989e1625f6..3376a6d658f 100644
--- a/src/snowflake/snowpark/_internal/type_utils.py
+++ b/src/snowflake/snowpark/_internal/type_utils.py
@@ -484,7 +484,7 @@ def infer_schema(
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, infer_type(v), v is None))
+            fields.append(StructField(k, infer_type(v)))
         except TypeError as e:
             raise TypeError(f"Unable to infer the type of the field {k}.") from e
     return StructType(fields)
diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py
index 7c91222181b..c71027d45b2 100644
--- a/tests/integ/test_dataframe.py
+++ b/tests/integ/test_dataframe.py
@@ -2734,15 +2734,15 @@ def test_save_as_table_nullable_test(
             StructField("B", data_type, True),
         ]
     )
-    df = session.create_dataframe(
-        [(None, None)] * (5000 if large_data else 1), schema=schema
-    )

     try:
         with pytest.raises(
             (IntegrityError, SnowparkSQLException),
             match="NULL result in a non-nullable column",
         ):
+            df = session.create_dataframe(
+                [(None, None)] * (5000 if large_data else 1), schema=schema
+            )
             df.write.save_as_table(table_name, mode=save_mode)
     finally:
         Utils.drop_table(session, table_name)

From a56e8e3b61cc0dfd4a8cca22a5f6164d5297bb40 Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Fri, 10 Jan 2025 13:18:37 -0800
Subject: [PATCH 4/7] test

---
 src/snowflake/snowpark/_internal/type_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py
index 3376a6d658f..a989e1625f6 100644
--- a/src/snowflake/snowpark/_internal/type_utils.py
+++ b/src/snowflake/snowpark/_internal/type_utils.py
@@ -484,7 +484,7 @@ def infer_schema(
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, infer_type(v)))
+            fields.append(StructField(k, infer_type(v), v is None))
         except TypeError as e:
             raise TypeError(f"Unable to infer the type of the field {k}.") from e
     return StructType(fields)

From 9906e09f81105720c2eb6af9300a773c1310349e Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Fri, 10 Jan 2025 13:29:42 -0800
Subject: [PATCH 5/7] fix test

---
 tests/integ/test_dataframe.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py
index c71027d45b2..48bd65192c3 100644
--- a/tests/integ/test_dataframe.py
+++ b/tests/integ/test_dataframe.py
@@ -1828,8 +1828,9 @@ def test_create_dataframe_with_schema_col_names(session):
         assert Utils.equals_ignore_case(field.name, expected_name)

     # the column names provided via schema keyword will overwrite other column names
+    struct_col_name = StructType([StructField(col, StringType()) for col in col_names])
     df = session.create_dataframe(
-        [{"aa": 1, "bb": 2, "cc": 3, "dd": 4}], schema=col_names
+        [{"aa": 1, "bb": 2, "cc": 3, "dd": 4}], schema=struct_col_name
     )
     for field, expected_name in zip(df.schema.fields, col_names):
         assert Utils.equals_ignore_case(field.name, expected_name)

From 9c62a495bb21b4b58592f1c9532cf9011907826f Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Fri, 10 Jan 2025 13:55:34 -0800
Subject: [PATCH 6/7] fix test

---
 tests/integ/test_dataframe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py
index 48bd65192c3..b924c45949b 100644
--- a/tests/integ/test_dataframe.py
+++ b/tests/integ/test_dataframe.py
@@ -2769,13 +2769,13 @@ def mock_run_query(*args, **kwargs):
             StructField("B", IntegerType(), True),
         ]
     )
-    df = session.create_dataframe([(None, None)], schema=schema)

     try:
         with pytest.raises(
             (IntegrityError, SnowparkSQLException),
             match="NULL result in a non-nullable column",
         ):
+            df = session.create_dataframe([(None, None)], schema=schema)
             df.write.save_as_table(table_name, mode=save_mode)
     finally:
         Utils.drop_table(session, table_name)

From 827e093b4bb7adaf2e0c3a52e0b8739fd897b97e Mon Sep 17 00:00:00 2001
From: Yuyang Wang
Date: Tue, 14 Jan 2025 14:49:59 -0800
Subject: [PATCH 7/7] address comments

---
 CHANGELOG.md                                            | 1 -
 .../snowpark/_internal/analyzer/snowflake_plan_node.py  | 9 +++++----
 tests/integ/test_dataframe.py                           | 4 +++-
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3a04ab68e34..100b9e5c9b7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,7 +24,6 @@
 #### Bug Fixes

 - Fixed a bug in local testing mode that caused a column to contain None when it should contain 0
-- Fixed a bug in StructField.from_json that prevented TimestampTypes with tzinfo from being parsed correctly.
 - Fixed a bug in `StructField.from_json` that prevented TimestampTypes with tzinfo from being parsed correctly.
 - Fixed a bug in function `date_format` that caused an error when the input column was date type or timestamp type.
 - Fixed a bug in dataframe creation that allowed null values to be inserted into a non-nullable column.
diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
index a5cee66d319..e9f1c0575ca 100644
--- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
+++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py
@@ -165,10 +165,11 @@ def is_contain_illegal_null_value(self) -> bool:
         rows_to_compare = min(
             ARRAY_BIND_THRESHOLD // len(self.output) + 1, len(self.data)
         )
-        for i in range(rows_to_compare):
-            for j in range(len(self.output)):
-                if self.data[i][j] is None and not self.output[j].nullable:
-                    return True
+        for j in range(len(self.output)):
+            if not self.output[j].nullable:
+                for i in range(rows_to_compare):
+                    if self.data[i][j] is None:
+                        return True
         return False

     @property
diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py
index 834a872fab1..86858eedf47 100644
--- a/tests/integ/test_dataframe.py
+++ b/tests/integ/test_dataframe.py
@@ -1827,8 +1827,10 @@ def test_create_dataframe_with_schema_col_names(session):
     for field, expected_name in zip(df.schema.fields, col_names[:2] + ["_3", "_4"]):
         assert Utils.equals_ignore_case(field.name, expected_name)

-    # the column names provided via schema keyword will overwrite other column names
+    # specify nullable in the StructType to avoid inserting a null value into a non-nullable column
     struct_col_name = StructType([StructField(col, StringType()) for col in col_names])
+
+    # the column names provided via schema keyword will overwrite other column names
     df = session.create_dataframe(
         [{"aa": 1, "bb": 2, "cc": 3, "dd": 4}], schema=struct_col_name
     )
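
Reviewer note: the sketch below restates, outside the Snowpark internals, the client-side validation that patches 2 and 7 converge on. It is a simplified model under stated assumptions, not the shipped code: the Attribute tuple and the default threshold value are stand-ins, while the real logic lives in the is_contain_illegal_null_value property and in do_resolve_with_resolved_children, which raises IntegrityError("NULL result in a non-nullable column") before any SQL is generated for small local data.

# Hypothetical, self-contained model of the nullability check added by this series.
# "Attribute" and the threshold default below are stand-ins for Snowpark internals.
from typing import Any, List, NamedTuple, Tuple


class Attribute(NamedTuple):
    name: str
    nullable: bool


def contains_illegal_null(
    data: List[Tuple[Any, ...]],
    output: List[Attribute],
    array_bind_threshold: int = 512,  # assumed stand-in for ARRAY_BIND_THRESHOLD
) -> bool:
    # Only the rows that could be inlined into a VALUES statement are scanned;
    # in the patch, larger local data skips this check and takes a different path.
    rows_to_compare = min(array_bind_threshold // len(output) + 1, len(data))
    # Patch 7's refactor: iterate columns first, so fully nullable columns never
    # trigger a scan of the data at all.
    for j, attribute in enumerate(output):
        if not attribute.nullable:
            for i in range(rows_to_compare):
                if data[i][j] is None:
                    return True
    return False


# A None in a non-nullable column is flagged client-side, before any SQL is sent.
assert contains_illegal_null([(None,), (1,)], [Attribute("KEY", nullable=False)])
assert not contains_illegal_null([(None,), (1,)], [Attribute("KEY", nullable=True)])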