diff --git a/src/main/scala/com/snowflake/snowpark/Updatable.scala b/src/main/scala/com/snowflake/snowpark/Updatable.scala index f7dbac81..be31a38d 100644 --- a/src/main/scala/com/snowflake/snowpark/Updatable.scala +++ b/src/main/scala/com/snowflake/snowpark/Updatable.scala @@ -10,7 +10,11 @@ private[snowpark] object Updatable extends Logging { new Updatable(tableName, session, DataFrame.methodChainCache.value) private[snowpark] def getUpdateResult(rows: Array[Row]): UpdateResult = - UpdateResult(rows.head.getLong(0), rows.head.getLong(1)) + UpdateResult(rows.head.getLong(0), if (rows.head.length == 1) { + 0 + } else { + rows.head.getLong(1) + }) private[snowpark] def getDeleteResult(rows: Array[Row]): DeleteResult = DeleteResult(rows.head.getLong(0)) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java index 255cc880..39089fce 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java @@ -350,6 +350,29 @@ public void cloneTest() { } } + @Test + public void singleColumnUpdateResult() { + Map params = new HashMap<>(); + params.put("ERROR_ON_NONDETERMINISTIC_UPDATE", "true"); + withSessionParameters( + params, + getSession(), + false, + () -> { + String tableName = randomName(); + try { + createTestTable(tableName); + Map map = new HashMap<>(); + map.put(Functions.col("col1"), Functions.lit(3)); + UpdateResult result = getSession().table(tableName).update(map); + assert result.getRowsUpdated() == 2; + assert result.getMultiJoinedRowsUpdated() == 0; + } finally { + dropTable(tableName); + } + }); + } + private void createTestTable(String name) { Row[] data = {Row.create(1, "a", true), Row.create(2, "b", false)}; StructType schema = diff --git a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala index cba5a70d..71434964 100644 --- a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala +++ b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala @@ -310,7 +310,8 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S ("FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), ("ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), ("ENABLE_STRUCTURED_TYPES_IN_BINDS", "enable")), - currentSession, skipPreprod = true)(thunk) + currentSession, + skipPreprod = true)(thunk) // disable these tests on preprod daily tests until these parameters are enabled by default. } } diff --git a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala index 2360866e..302f366a 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala @@ -388,4 +388,14 @@ class UpdatableSuite extends TestData { dropTable(tableName) } } + + test("ERROR_ON_NONDETERMINISTIC_UPDATE = true") { + withSessionParameters(Seq(("ERROR_ON_NONDETERMINISTIC_UPDATE", "true")), session) { + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + val updatable = session.table(tableName) + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + assert(updatable.update(Map(col("a") -> lit(1), col("b") -> lit(0))) == UpdateResult(6, 0)) + } + } } +