diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index 443518c6..192d9705 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -79,8 +79,7 @@ public DataFrame sql(String query) { * @since 0.8.0 */ public DataFrame sql(String query, Object... params) { - return new DataFrame( - session.sql(query, JavaUtils.objectArrayToSeq(params))); + return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(params))); } /** diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index e855afb3..5fd3feb2 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -288,8 +288,7 @@ private[snowpark] class ServerConnection( true, false, getStatementParameters(isDDLOnTempObject = false, Map.empty), - Seq.empty - ).rows.get + Seq.empty).rows.get .map(r => r.getString(0).toLowerCase() + PackageNameDelimiter + r.getString(1).toLowerCase()) .toSet @@ -298,7 +297,7 @@ private[snowpark] class ServerConnection( statement: PreparedStatement, params: Seq[Any]): Unit = params.zipWithIndex.foreach { - case (p, i) => statement.setObject(i + 1, p) + case (p, i) => statement.setObject(i + 1, p) } private[snowflake] def setStatementParameters( @@ -883,7 +882,8 @@ private[snowpark] class ServerConnection( val statement = connection.prepareStatement(multipleStatements) try { // Note binding parameters only supported for single query - val bindingParameters = if (plan.queries.length == 1) plan.queries.last.params else Seq() + val bindingParameters = + if (plan.queries.length == 1) plan.queries.last.params else Seq() val statementParameters = getStatementParameters() + ("MULTI_STATEMENT_COUNT" -> plan.queries.size) setBindingParameters(statement, bindingParameters) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala index 0c12b8f5..0a32e0a4 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala @@ -292,8 +292,7 @@ class SnowflakePlanBuilder(session: Session) extends Logging { selectLeft.queries.slice(0, selectLeft.queries.length - 1) ++ selectRight.queries.slice(0, selectRight.queries.length - 1) :+ Query( sqlGenerator(lastQueryLeft.sql, lastQueryRight.sql), - lastQueryLeft.params ++ lastQueryRight.params - ) + lastQueryLeft.params ++ lastQueryRight.params) val leftSchemaQuery = schemaValueStatement(selectLeft.attributes) val rightSchemaQuery = schemaValueStatement(selectRight.attributes) val schemaQuery = sqlGenerator(leftSchemaQuery, rightSchemaQuery) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java index 4d192738..3fb4a626 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java @@ -7,8 +7,6 @@ import com.snowflake.snowpark_java.types.StructField; import com.snowflake.snowpark_java.types.StructType; import java.sql.Connection; -import java.util.Arrays; - import net.snowflake.client.jdbc.SnowflakeConnection; import org.junit.Test; @@ -67,13 +65,11 @@ public void generator() { public void sql() { checkAnswer( getSession().sql("select * from values(1, 2),(3, 4) as t(a, b)"), - new Row[] {Row.create(1, 2), Row.create(3, 4)} - ); + new Row[] {Row.create(1, 2), Row.create(3, 4)}); checkAnswer( getSession().sql("select * from values(?, ?),(?, ?) as t(a, b)", 1, 2, 3, 4), - new Row[] {Row.create(1, 2), Row.create(3, 4)} - ); + new Row[] {Row.create(1, 2), Row.create(3, 4)}); } @Test diff --git a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala index 1817a078..d0478606 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala @@ -105,7 +105,6 @@ trait SqlSuite extends SNTestBase { } } - test("Run sql query with bindings") { val df1 = session.sql("select * from values (?),(?),(?)", List(1, 2, 3)) assert(df1.collect() sameElements Array[Row](Row(1), Row(2), Row(3))) @@ -121,10 +120,10 @@ trait SqlSuite extends SNTestBase { .filter(col("id") < 3) assert(df3.collect() sameElements Array[Row](Row(1), Row(2))) - val df4 = session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", Seq(1, 1, 2, 1, 3, 1)) - val df5 = session.sql( - "select * from values (?,?),(?,?),(?,?) as T(a, b)", - List(1, 2, 2, 1, 4, 3)) + val df4 = + session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", Seq(1, 1, 2, 1, 3, 1)) + val df5 = + session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", List(1, 2, 2, 1, 4, 3)) val df6 = df4.union(df5).filter(col("a") < 3) assert(df6.collect() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) @@ -133,7 +132,8 @@ trait SqlSuite extends SNTestBase { // Async result assert(df1.async.collect().getResult() sameElements Array[Row](Row(1), Row(2), Row(3))) - assert(df6.async.collect().getResult() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) + assert( + df6.async.collect().getResult() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) } }