Skip to content

Commit

Permalink
SNOW-1551862 Don't Report Error in UpdateResult when ERROR_ON_NONDETE…
Browse files Browse the repository at this point in the history
…RMINISTIC_UPDATE=true (#126)

* fix merge

* fix update result
  • Loading branch information
sfc-gh-bli committed Aug 5, 2024
1 parent 68f88b3 commit 25775bc
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/main/scala/com/snowflake/snowpark/Updatable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ private[snowpark] object Updatable extends Logging {
new Updatable(tableName, session, DataFrame.methodChainCache.value)

private[snowpark] def getUpdateResult(rows: Array[Row]): UpdateResult =
UpdateResult(rows.head.getLong(0), rows.head.getLong(1))
UpdateResult(rows.head.getLong(0), if (rows.head.length == 1) {
0
} else {
rows.head.getLong(1)
})

private[snowpark] def getDeleteResult(rows: Array[Row]): DeleteResult =
DeleteResult(rows.head.getLong(0))
Expand Down
23 changes: 23 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaUpdatableSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,29 @@ public void cloneTest() {
}
}

@Test
public void singleColumnUpdateResult() {
Map<String, String> params = new HashMap<>();
params.put("ERROR_ON_NONDETERMINISTIC_UPDATE", "true");
withSessionParameters(
params,
getSession(),
false,
() -> {
String tableName = randomName();
try {
createTestTable(tableName);
Map<Column, Column> map = new HashMap<>();
map.put(Functions.col("col1"), Functions.lit(3));
UpdateResult result = getSession().table(tableName).update(map);
assert result.getRowsUpdated() == 2;
assert result.getMultiJoinedRowsUpdated() == 0;
} finally {
dropTable(tableName);
}
});
}

private void createTestTable(String name) {
Row[] data = {Row.create(1, "a", true), Row.create(2, "b", false)};
StructType schema =
Expand Down
30 changes: 30 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/TestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.snowflake.snowpark.TestUtils;
import com.snowflake.snowpark_java.JavaToScalaConvertor;
import com.snowflake.snowpark_java.Session;
import java.util.Map;
import java.util.Optional;

public abstract class TestBase extends TestFunctions {

Expand All @@ -20,6 +22,34 @@ protected Session getSession() {
return _session;
}

Optional<Boolean> isPreprodAccount = Optional.empty();

protected boolean isPreprodAccount() {
if (!isPreprodAccount.isPresent()) {
isPreprodAccount =
Optional.of(
!getSession()
.sql("select current_account()")
.collect()[0]
.getString(0)
.contains("SFCTEST0"));
}
return isPreprodAccount.get();
}

protected void withSessionParameters(
Map<String, String> params, Session session, boolean skipPreprod, TestMethod thunk) {
if (!(skipPreprod && isPreprodAccount())) {
try {
params.forEach(
(name, value) -> session.sql("alter session set " + name + " = " + value).collect());
thunk.run();
} finally {
params.forEach((name, value) -> session.sql("alter session unset " + name).collect());
}
}
}

protected void runQuery(String sql) {
getSession().sql(sql).collect();
}
Expand Down
6 changes: 6 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/TestMethod.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.snowflake.snowpark_test;

@FunctionalInterface
public interface TestMethod {
void run();
}
35 changes: 35 additions & 0 deletions src/test/scala/com/snowflake/snowpark/SNTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,41 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S
}
}
}

lazy val isPreprodAccount: Boolean =
!session.sql("select current_account()").collect().head.getString(0).contains("SFCTEST0")

def withSessionParameters(
params: Seq[(String, String)],
currentSession: Session,
skipPreprod: Boolean = false)(thunk: => Unit): Unit = {
if (!(skipPreprod && isPreprodAccount)) {
try {
params.foreach {
case (paramName, value) =>
runQuery(s"alter session set $paramName = $value", currentSession)
}
thunk
} finally {
params.foreach {
case (paramName, _) => runQuery(s"alter session unset $paramName", currentSession)
}
}
}
}

def structuredTypeTest(thunk: => Unit)(implicit currentSession: Session): Unit = {
withSessionParameters(
Seq(
("ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE", "true"),
("IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE", "true"),
("FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"),
("ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"),
("ENABLE_STRUCTURED_TYPES_IN_BINDS", "enable")),
currentSession,
skipPreprod = true)(thunk)
// disable these tests on preprod daily tests until these parameters are enabled by default.
}
}

trait SnowTestFiles {
Expand Down
10 changes: 10 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -388,4 +388,14 @@ class UpdatableSuite extends TestData {
dropTable(tableName)
}
}

test("ERROR_ON_NONDETERMINISTIC_UPDATE = true") {
withSessionParameters(Seq(("ERROR_ON_NONDETERMINISTIC_UPDATE", "true")), session) {
testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName)
val updatable = session.table(tableName)
testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName)
assert(updatable.update(Map(col("a") -> lit(1), col("b") -> lit(0))) == UpdateResult(6, 0))
}
}
}

0 comments on commit 25775bc

Please sign in to comment.