diff --git a/src/main/scala/com/snowflake/snowpark/Updatable.scala b/src/main/scala/com/snowflake/snowpark/Updatable.scala index f7dbac81..be31a38d 100644 --- a/src/main/scala/com/snowflake/snowpark/Updatable.scala +++ b/src/main/scala/com/snowflake/snowpark/Updatable.scala @@ -10,7 +10,11 @@ private[snowpark] object Updatable extends Logging { new Updatable(tableName, session, DataFrame.methodChainCache.value) private[snowpark] def getUpdateResult(rows: Array[Row]): UpdateResult = - UpdateResult(rows.head.getLong(0), rows.head.getLong(1)) + UpdateResult(rows.head.getLong(0), if (rows.head.length == 1) { + 0 + } else { + rows.head.getLong(1) + }) private[snowpark] def getDeleteResult(rows: Array[Row]): DeleteResult = DeleteResult(rows.head.getLong(0)) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java index 255cc880..39089fce 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java @@ -350,6 +350,29 @@ public void cloneTest() { } } + @Test + public void singleColumnUpdateResult() { + Map params = new HashMap<>(); + params.put("ERROR_ON_NONDETERMINISTIC_UPDATE", "true"); + withSessionParameters( + params, + getSession(), + false, + () -> { + String tableName = randomName(); + try { + createTestTable(tableName); + Map map = new HashMap<>(); + map.put(Functions.col("col1"), Functions.lit(3)); + UpdateResult result = getSession().table(tableName).update(map); + assert result.getRowsUpdated() == 2; + assert result.getMultiJoinedRowsUpdated() == 0; + } finally { + dropTable(tableName); + } + }); + } + private void createTestTable(String name) { Row[] data = {Row.create(1, "a", true), Row.create(2, "b", false)}; StructType schema = diff --git a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala index cba5a70d..71434964 100644 --- a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala +++ b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala @@ -310,7 +310,8 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S ("FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), ("ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), ("ENABLE_STRUCTURED_TYPES_IN_BINDS", "enable")), - currentSession, skipPreprod = true)(thunk) + currentSession, + skipPreprod = true)(thunk) // disable these tests on preprod daily tests until these parameters are enabled by default. } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameNonStoredProcSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameNonStoredProcSuite.scala new file mode 100644 index 00000000..66a34b0b --- /dev/null +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameNonStoredProcSuite.scala @@ -0,0 +1,85 @@ +package com.snowflake.snowpark_test + +import com.snowflake.snowpark.TestData +import com.snowflake.snowpark.functions.col + +class DataFrameNonStoredProcSuite extends TestData { + + private def testDataframeStatPivot(): Unit = { + assert( + getShowString(monthlySales.stat.crosstab("empid", "month").sort(col("empid")), 10) == + """--------------------------------------------------- + ||"EMPID" |"'JAN'" |"'FEB'" |"'MAR'" |"'APR'" | + |--------------------------------------------------- + ||1 |2 |2 |2 |2 | + ||2 |2 |2 |2 |2 | + |--------------------------------------------------- + |""".stripMargin) + + assert( + getShowString(monthlySales.stat.crosstab("month", "empid").sort(col("month")), 10) == + """------------------------------------------------------------------- + ||"MONTH" |"CAST(1 AS NUMBER(38,0))" |"CAST(2 AS NUMBER(38,0))" | + |------------------------------------------------------------------- + ||APR |2 |2 | + ||FEB |2 |2 | + ||JAN |2 |2 | + ||MAR |2 |2 | + |------------------------------------------------------------------- + |""".stripMargin) + + assert( + getShowString(date1.stat.crosstab("a", "b").sort(col("a")), 10) == + """---------------------------------------------------------------------- + ||"A" |"CAST(1 AS NUMBER(38,0))" |"CAST(2 AS NUMBER(38,0))" | + |---------------------------------------------------------------------- + ||2010-12-01 |0 |1 | + ||2020-08-01 |1 |0 | + |---------------------------------------------------------------------- + |""".stripMargin) + + assert( + getShowString(date1.stat.crosstab("b", "a").sort(col("b")), 10) == + """----------------------------------------------------------- + ||"B" |"TO_DATE('2020-08-01')" |"TO_DATE('2010-12-01')" | + |----------------------------------------------------------- + ||1 |1 |0 | + ||2 |0 |1 | + |----------------------------------------------------------- + |""".stripMargin) + + assert( + getShowString(string7.stat.crosstab("a", "b").sort(col("a")), 10) == + """---------------------------------------------------------------- + ||"A" |"CAST(1 AS NUMBER(38,0))" |"CAST(2 AS NUMBER(38,0))" | + |---------------------------------------------------------------- + ||NULL |0 |1 | + ||str |1 |0 | + |---------------------------------------------------------------- + |""".stripMargin) + + assert( + getShowString(string7.stat.crosstab("b", "a").sort(col("b")), 10) == + """-------------------------- + ||"B" |"'str'" |"NULL" | + |-------------------------- + ||1 |1 |0 | + ||2 |0 |0 | + |-------------------------- + |""".stripMargin) + } + + test("df.stat.pivot") { + testWithAlteredSessionParameter( + testDataframeStatPivot(), + "ENABLE_PIVOT_VIEW_WITH_OBJECT_AGG", + "disable", + skipIfParamNotExist = true) + + testWithAlteredSessionParameter( + testDataframeStatPivot(), + "ENABLE_PIVOT_VIEW_WITH_OBJECT_AGG", + "enable", + skipIfParamNotExist = true) + } +} diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala index 45aec853..31733928 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala @@ -469,70 +469,6 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert(double2.stat.approxQuantile(Array[String](), Array[Double]()).isEmpty) } - test("df.stat.pivot") { - assert( - getShowString(monthlySales.stat.crosstab("empid", "month"), 10) == - """--------------------------------------------------- - ||"EMPID" |"'JAN'" |"'FEB'" |"'MAR'" |"'APR'" | - |--------------------------------------------------- - ||1 |2 |2 |2 |2 | - ||2 |2 |2 |2 |2 | - |--------------------------------------------------- - |""".stripMargin) - - assert( - getShowString(monthlySales.stat.crosstab("month", "empid"), 10) == - """------------------------------------------------------------------- - ||"MONTH" |"CAST(1 AS NUMBER(38,0))" |"CAST(2 AS NUMBER(38,0))" | - |------------------------------------------------------------------- - ||JAN |2 |2 | - ||FEB |2 |2 | - ||MAR |2 |2 | - ||APR |2 |2 | - |------------------------------------------------------------------- - |""".stripMargin) - - assert( - getShowString(date1.stat.crosstab("a", "b"), 10) == - """---------------------------------------------------------------------- - ||"A" |"CAST(1 AS NUMBER(38,0))" |"CAST(2 AS NUMBER(38,0))" | - |---------------------------------------------------------------------- - ||2020-08-01 |1 |0 | - ||2010-12-01 |0 |1 | - |---------------------------------------------------------------------- - |""".stripMargin) - - assert( - getShowString(date1.stat.crosstab("b", "a"), 10) == - """----------------------------------------------------------- - ||"B" |"TO_DATE('2020-08-01')" |"TO_DATE('2010-12-01')" | - |----------------------------------------------------------- - ||1 |1 |0 | - ||2 |0 |1 | - |----------------------------------------------------------- - |""".stripMargin) - - assert( - getShowString(string7.stat.crosstab("a", "b"), 10) == - """---------------------------------------------------------------- - ||"A" |"CAST(1 AS NUMBER(38,0))" |"CAST(2 AS NUMBER(38,0))" | - |---------------------------------------------------------------- - ||str |1 |0 | - ||NULL |0 |1 | - |---------------------------------------------------------------- - |""".stripMargin) - - assert( - getShowString(string7.stat.crosstab("b", "a"), 10) == - """-------------------------- - ||"B" |"'str'" |"NULL" | - |-------------------------- - ||1 |1 |0 | - ||2 |0 |0 | - |-------------------------- - |""".stripMargin) - } - test("df.stat.sampleBy") { assert( getShowString(monthlySales.stat.sampleBy(col("empid"), Map(1 -> 0.0, 2 -> 1.0)), 10) == diff --git a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala index 2360866e..302f366a 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala @@ -388,4 +388,14 @@ class UpdatableSuite extends TestData { dropTable(tableName) } } + + test("ERROR_ON_NONDETERMINISTIC_UPDATE = true") { + withSessionParameters(Seq(("ERROR_ON_NONDETERMINISTIC_UPDATE", "true")), session) { + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + val updatable = session.table(tableName) + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + assert(updatable.update(Map(col("a") -> lit(1), col("b") -> lit(0))) == UpdateResult(6, 0)) + } + } } +