diff --git a/src/main/scala/com/snowflake/snowpark/Updatable.scala b/src/main/scala/com/snowflake/snowpark/Updatable.scala index f7dbac81..be31a38d 100644 --- a/src/main/scala/com/snowflake/snowpark/Updatable.scala +++ b/src/main/scala/com/snowflake/snowpark/Updatable.scala @@ -10,7 +10,11 @@ private[snowpark] object Updatable extends Logging { new Updatable(tableName, session, DataFrame.methodChainCache.value) private[snowpark] def getUpdateResult(rows: Array[Row]): UpdateResult = - UpdateResult(rows.head.getLong(0), rows.head.getLong(1)) + UpdateResult(rows.head.getLong(0), if (rows.head.length == 1) { + 0 + } else { + rows.head.getLong(1) + }) private[snowpark] def getDeleteResult(rows: Array[Row]): DeleteResult = DeleteResult(rows.head.getLong(0)) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java index 255cc880..39089fce 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java @@ -350,6 +350,29 @@ public void cloneTest() { } } + @Test + public void singleColumnUpdateResult() { + Map params = new HashMap<>(); + params.put("ERROR_ON_NONDETERMINISTIC_UPDATE", "true"); + withSessionParameters( + params, + getSession(), + false, + () -> { + String tableName = randomName(); + try { + createTestTable(tableName); + Map map = new HashMap<>(); + map.put(Functions.col("col1"), Functions.lit(3)); + UpdateResult result = getSession().table(tableName).update(map); + assert result.getRowsUpdated() == 2; + assert result.getMultiJoinedRowsUpdated() == 0; + } finally { + dropTable(tableName); + } + }); + } + private void createTestTable(String name) { Row[] data = {Row.create(1, "a", true), Row.create(2, "b", false)}; StructType schema = diff --git a/src/test/java/com/snowflake/snowpark_test/TestBase.java b/src/test/java/com/snowflake/snowpark_test/TestBase.java index aa0f5d5e..26a64cd0 100644 --- a/src/test/java/com/snowflake/snowpark_test/TestBase.java +++ b/src/test/java/com/snowflake/snowpark_test/TestBase.java @@ -3,6 +3,8 @@ import com.snowflake.snowpark.TestUtils; import com.snowflake.snowpark_java.JavaToScalaConvertor; import com.snowflake.snowpark_java.Session; +import java.util.Map; +import java.util.Optional; public abstract class TestBase extends TestFunctions { @@ -20,6 +22,34 @@ protected Session getSession() { return _session; } + Optional isPreprodAccount = Optional.empty(); + + protected boolean isPreprodAccount() { + if (!isPreprodAccount.isPresent()) { + isPreprodAccount = + Optional.of( + !getSession() + .sql("select current_account()") + .collect()[0] + .getString(0) + .contains("SFCTEST0")); + } + return isPreprodAccount.get(); + } + + protected void withSessionParameters( + Map params, Session session, boolean skipPreprod, TestMethod thunk) { + if (!(skipPreprod && isPreprodAccount())) { + try { + params.forEach( + (name, value) -> session.sql("alter session set " + name + " = " + value).collect()); + thunk.run(); + } finally { + params.forEach((name, value) -> session.sql("alter session unset " + name).collect()); + } + } + } + protected void runQuery(String sql) { getSession().sql(sql).collect(); } diff --git a/src/test/java/com/snowflake/snowpark_test/TestMethod.java b/src/test/java/com/snowflake/snowpark_test/TestMethod.java new file mode 100644 index 00000000..265991f9 --- /dev/null +++ b/src/test/java/com/snowflake/snowpark_test/TestMethod.java @@ -0,0 +1,6 @@ +package com.snowflake.snowpark_test; + +@FunctionalInterface +public interface TestMethod { + void run(); +} diff --git a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala index 3fd98755..c09a9c9f 100644 --- a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala +++ b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala @@ -279,6 +279,41 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S } } } + + lazy val isPreprodAccount: Boolean = + !session.sql("select current_account()").collect().head.getString(0).contains("SFCTEST0") + + def withSessionParameters( + params: Seq[(String, String)], + currentSession: Session, + skipPreprod: Boolean = false)(thunk: => Unit): Unit = { + if (!(skipPreprod && isPreprodAccount)) { + try { + params.foreach { + case (paramName, value) => + runQuery(s"alter session set $paramName = $value", currentSession) + } + thunk + } finally { + params.foreach { + case (paramName, _) => runQuery(s"alter session unset $paramName", currentSession) + } + } + } + } + + def structuredTypeTest(thunk: => Unit)(implicit currentSession: Session): Unit = { + withSessionParameters( + Seq( + ("ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE", "true"), + ("IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE", "true"), + ("FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), + ("ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), + ("ENABLE_STRUCTURED_TYPES_IN_BINDS", "enable")), + currentSession, + skipPreprod = true)(thunk) + // disable these tests on preprod daily tests until these parameters are enabled by default. + } } trait SnowTestFiles { diff --git a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala index 2360866e..302f366a 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala @@ -388,4 +388,14 @@ class UpdatableSuite extends TestData { dropTable(tableName) } } + + test("ERROR_ON_NONDETERMINISTIC_UPDATE = true") { + withSessionParameters(Seq(("ERROR_ON_NONDETERMINISTIC_UPDATE", "true")), session) { + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + val updatable = session.table(tableName) + testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) + assert(updatable.update(Map(col("a") -> lit(1), col("b") -> lit(0))) == UpdateResult(6, 0)) + } + } } +