From 9c45bc48779285d1f26c5d7bb715f3acad87e4e4 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Fri, 26 Jul 2024 10:17:32 -0700 Subject: [PATCH 1/2] fix merge --- src/main/scala/com/snowflake/snowpark/Updatable.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/main/scala/com/snowflake/snowpark/Updatable.scala b/src/main/scala/com/snowflake/snowpark/Updatable.scala index f7dbac81..6638510d 100644 --- a/src/main/scala/com/snowflake/snowpark/Updatable.scala +++ b/src/main/scala/com/snowflake/snowpark/Updatable.scala @@ -11,6 +11,13 @@ private[snowpark] object Updatable extends Logging { private[snowpark] def getUpdateResult(rows: Array[Row]): UpdateResult = UpdateResult(rows.head.getLong(0), rows.head.getLong(1)) +// UpdateResult(rows.head.getLong(0), +// if (rows.head.length == 1) { +// 0 +// } else { +// rows.head.getLong(1) +// }) + private[snowpark] def getDeleteResult(rows: Array[Row]): DeleteResult = DeleteResult(rows.head.getLong(0)) From 24788d84ece69bd96560bc17cfaf220a36026f4c Mon Sep 17 00:00:00 2001 From: Bing Li Date: Fri, 26 Jul 2024 10:32:25 -0700 Subject: [PATCH 2/2] fix update result --- .../com/snowflake/snowpark/Updatable.scala | 13 ++++------- .../snowpark_test/JavaUpdatableSuite.java | 23 +++++++++++++++++++ .../com/snowflake/snowpark/SNTestBase.scala | 3 ++- .../snowpark_test/UpdatableSuite.scala | 10 ++++++++ 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/Updatable.scala b/src/main/scala/com/snowflake/snowpark/Updatable.scala index 6638510d..be31a38d 100644 --- a/src/main/scala/com/snowflake/snowpark/Updatable.scala +++ b/src/main/scala/com/snowflake/snowpark/Updatable.scala @@ -10,14 +10,11 @@ private[snowpark] object Updatable extends Logging { new Updatable(tableName, session, DataFrame.methodChainCache.value) private[snowpark] def getUpdateResult(rows: Array[Row]): UpdateResult = - UpdateResult(rows.head.getLong(0), rows.head.getLong(1)) -// UpdateResult(rows.head.getLong(0), -// if (rows.head.length == 1) { -// 0 -// } else { -// rows.head.getLong(1) -// }) - + UpdateResult(rows.head.getLong(0), if (rows.head.length == 1) { + 0 + } else { + rows.head.getLong(1) + }) private[snowpark] def getDeleteResult(rows: Array[Row]): DeleteResult = DeleteResult(rows.head.getLong(0)) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java index 255cc880..39089fce 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java @@ -350,6 +350,29 @@ public void cloneTest() { } } + @Test + public void singleColumnUpdateResult() { + Map params = new HashMap<>(); + params.put("ERROR_ON_NONDETERMINISTIC_UPDATE", "true"); + withSessionParameters( + params, + getSession(), + false, + () -> { + String tableName = randomName(); + try { + createTestTable(tableName); + Map map = new HashMap<>(); + map.put(Functions.col("col1"), Functions.lit(3)); + UpdateResult result = getSession().table(tableName).update(map); + assert result.getRowsUpdated() == 2; + assert result.getMultiJoinedRowsUpdated() == 0; + } finally { + dropTable(tableName); + } + }); + } + private void createTestTable(String name) { Row[] data = {Row.create(1, "a", true), Row.create(2, "b", false)}; StructType schema = diff --git a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala index cba5a70d..71434964 100644 --- a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala +++ b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala @@ -310,7 +310,8 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S ("FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), ("ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), ("ENABLE_STRUCTURED_TYPES_IN_BINDS", "enable")), - currentSession, skipPreprod = true)(thunk) + currentSession, + skipPreprod = true)(thunk) // disable these tests on preprod daily tests until these parameters are enabled by default. } } diff --git a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala index 2360866e..302f366a 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala @@ -388,4 +388,14 @@ class UpdatableSuite extends TestData { dropTable(tableName) } } + + test("ERROR_ON_NONDETERMINISTIC_UPDATE = true") { + withSessionParameters(Seq(("ERROR_ON_NONDETERMINISTIC_UPDATE", "true")), session) { + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + val updatable = session.table(tableName) + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + assert(updatable.update(Map(col("a") -> lit(1), col("b") -> lit(0))) == UpdateResult(6, 0)) + } + } } +