From e565db02224575cd31e2edbf4559713a173c7a3d Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Wed, 30 Oct 2024 18:15:40 +0000 Subject: [PATCH 1/4] SNOW-1022196 Support binding parameters for snowpark java api --- .../com/snowflake/snowpark_java/Session.java | 17 +++++++- .../com/snowflake/snowpark/Session.scala | 5 ++- .../snowpark/internal/ServerConnection.scala | 38 +++++++++++++----- .../internal/analyzer/SnowflakePlan.scala | 40 +++++++++++++------ .../snowpark_test/JavaSessionSuite.java | 15 +++++++ .../snowpark/ServerConnectionSuite.scala | 18 +++++++++ .../snowflake/snowpark_test/SqlSuite.scala | 32 +++++++++++++++ 7 files changed, 140 insertions(+), 25 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index c2f4ef6d..443518c6 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -65,7 +65,22 @@ public static SessionBuilder builder() { * @since 0.8.0 */ public DataFrame sql(String query) { - return new DataFrame(session.sql(query)); + return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(new Object[0]))); + } + + /** + * Returns a new {@code DataFrame} representing the results of a SQL query. + * + *

You can use this method to execute an arbitrary SQL statement. + * + * @param query The SQL statement to execute. + * @param params The binding parameters for SQL statement (optional) + * @return A {@code DataFrame} object + * @since 0.8.0 + */ + public DataFrame sql(String query, Object... params) { + return new DataFrame( + session.sql(query, JavaUtils.objectArrayToSeq(params))); } /** diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index 1beac1c0..8a670b90 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -945,12 +945,13 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log * You can use this method to execute an arbitrary SQL statement. * * @param query The SQL statement to execute. + * @param params for bind variables in SQL statement. * @return A [[DataFrame]] * @since 0.1.0 */ - def sql(query: String): DataFrame = { + def sql(query: String, params: Seq[Any] = Seq.empty): DataFrame = { // PUT and GET command cannot be executed in async mode - DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query))) + DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query), params)) } /** diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index a2281925..e855afb3 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -30,6 +30,7 @@ import net.snowflake.client.jdbc.{ SnowflakeBaseResultSet, SnowflakeConnectString, SnowflakeConnectionV1, + SnowflakePreparedStatement, SnowflakeReauthenticationRequest, SnowflakeResultSet, SnowflakeResultSetMetaData, @@ -286,11 +287,20 @@ private[snowpark] class ServerConnection( s"where language = 'java'", true, false, - 
getStatementParameters(isDDLOnTempObject = false, Map.empty)).rows.get + getStatementParameters(isDDLOnTempObject = false, Map.empty), + Seq.empty + ).rows.get .map(r => r.getString(0).toLowerCase() + PackageNameDelimiter + r.getString(1).toLowerCase()) .toSet + private[snowflake] def setBindingParameters( + statement: PreparedStatement, + params: Seq[Any]): Unit = + params.zipWithIndex.foreach { + case (p, i) => statement.setObject(i + 1, p) + } + private[snowflake] def setStatementParameters( statement: Statement, parameters: Map[String, Any]): Unit = @@ -438,12 +448,14 @@ private[snowpark] class ServerConnection( def runQuery( query: String, isDDLOnTempObject: Boolean = false, - statementParameters: Map[String, Any] = Map.empty): String = + statementParameters: Map[String, Any] = Map.empty, + params: Seq[Any] = Seq.empty): String = runQueryGetResult( query, returnRows = false, returnIterator = false, - getStatementParameters(isDDLOnTempObject, statementParameters)).queryId + getStatementParameters(isDDLOnTempObject, statementParameters), + params).queryId // Run the query and return the queryID when the caller doesn't need the ResultSet def runQueryGetRows( @@ -453,7 +465,8 @@ private[snowpark] class ServerConnection( query, returnRows = true, returnIterator = false, - getStatementParameters(isDDLOnTempObject = false, statementParameters)).rows.get + getStatementParameters(isDDLOnTempObject = false, statementParameters), + Seq.empty).rows.get // Run the query to get query result. // 1. 
If the caller needs to get Iterator[Row], the internal JDBC ResultSet and Statement @@ -466,11 +479,13 @@ private[snowpark] class ServerConnection( query: String, returnRows: Boolean, returnIterator: Boolean, - statementParameters: Map[String, Any]): QueryResult = + statementParameters: Map[String, Any], + params: Seq[Any]): QueryResult = withValidConnection { var statement: PreparedStatement = null try { statement = connection.prepareStatement(query) + setBindingParameters(statement, params) setStatementParameters(statement, statementParameters) val rs = statement.executeQuery() val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID @@ -857,20 +872,23 @@ private[snowpark] class ServerConnection( logDebug(s"""execute plan in async mode: |----------SNOW----------- + | |$plan |------------------------- |""".stripMargin) // use try finally to ensure postActions is always run - val statement = connection.createStatement() + val queries = plan.queries.map(_.sql) + val multipleStatements = queries.mkString("; ") + val statement = connection.prepareStatement(multipleStatements) try { - val queries = plan.queries.map(_.sql) - val multipleStatements = queries.mkString("; ") + // Note binding parameters only supported for single query + val bindingParameters = if (plan.queries.length == 1) plan.queries.last.params else Seq() val statementParameters = getStatementParameters() + ("MULTI_STATEMENT_COUNT" -> plan.queries.size) + setBindingParameters(statement, bindingParameters) setStatementParameters(statement, statementParameters) - val rs = - statement.asInstanceOf[SnowflakeStatement].executeAsyncQuery(multipleStatements) + val rs = statement.asInstanceOf[SnowflakePreparedStatement].executeAsyncQuery() val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID if (actionID <= plan.session.getLastCanceledID) { throw ErrorMessage.MISC_QUERY_IS_CANCELLED() diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala 
b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala index a3218758..0c12b8f5 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala @@ -266,8 +266,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging { schemaQuery: Option[String], isDDLOnTempObject: Boolean): SnowflakePlan = wrapException(child) { val selectChild = addResultScanIfNotSelect(child) + val lastQuery = selectChild.queries.last val queries: Seq[Query] = selectChild.queries.slice(0, selectChild.queries.length - 1) ++ - multipleSqlGenerator(selectChild.queries.last.sql).map(Query(_, isDDLOnTempObject)) + multipleSqlGenerator(lastQuery.sql).map(Query(_, isDDLOnTempObject, lastQuery.params)) val newSchemaQuery = schemaQuery.getOrElse(multipleSqlGenerator(child.schemaQuery).last) SnowflakePlan( queries, @@ -284,15 +285,18 @@ class SnowflakePlanBuilder(session: Session) extends Logging { right: SnowflakePlan, sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(left, right) { val selectLeft = addResultScanIfNotSelect(left) + val lastQueryLeft = selectLeft.queries.last val selectRight = addResultScanIfNotSelect(right) + val lastQueryRight = selectRight.queries.last val queries: Seq[Query] = selectLeft.queries.slice(0, selectLeft.queries.length - 1) ++ selectRight.queries.slice(0, selectRight.queries.length - 1) :+ Query( - sqlGenerator(selectLeft.queries.last.sql, selectRight.queries.last.sql)) + sqlGenerator(lastQueryLeft.sql, lastQueryRight.sql), + lastQueryLeft.params ++ lastQueryRight.params + ) val leftSchemaQuery = schemaValueStatement(selectLeft.attributes) val rightSchemaQuery = schemaValueStatement(selectRight.attributes) val schemaQuery = sqlGenerator(leftSchemaQuery, rightSchemaQuery) - val commonColumn = selectLeft.aliasMap.keySet.intersect(selectRight.aliasMap.keySet) val supportAsyncMode = selectLeft.supportAsyncMode && 
selectRight.supportAsyncMode SnowflakePlan( queries, @@ -308,10 +312,11 @@ class SnowflakePlanBuilder(session: Session) extends Logging { children: Seq[SnowflakePlan], sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(children: _*) { val selectChildren = children.map(addResultScanIfNotSelect) + val params: Seq[Any] = selectChildren.map(_.queries.last.params).flatten val queries: Seq[Query] = selectChildren .map(c => c.queries.slice(0, c.queries.length - 1)) - .reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql))) + .reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql)), params) val schemaQueries = children.map(c => schemaValueStatement(c.attributes)) val schemaQuery = sqlGenerator(schemaQueries) @@ -323,8 +328,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging { def query( sql: String, sourcePlan: Option[LogicalPlan], - supportAsyncMode: Boolean = true): SnowflakePlan = - SnowflakePlan(Seq(Query(sql)), sql, session, sourcePlan, supportAsyncMode) + supportAsyncMode: Boolean = true, + params: Seq[Any] = Seq.empty): SnowflakePlan = + SnowflakePlan(Seq(Query(sql, params)), sql, session, sourcePlan, supportAsyncMode) def largeLocalRelationPlan( output: Seq[Attribute], @@ -764,7 +770,8 @@ class SnowflakePlanBuilder(session: Session) extends Logging { private[snowpark] class Query( val sql: String, val queryIdPlaceHolder: String, - val isDDLOnTempObject: Boolean) + val isDDLOnTempObject: Boolean, + val params: Seq[Any]) extends Logging { logDebug(s"Creating a new Query: $sql ID: $queryIdPlaceHolder") override def toString: String = sql @@ -776,7 +783,7 @@ private[snowpark] class Query( placeholders.foreach { case (holder, id) => finalQuery = finalQuery.replaceAll(holder, id) } - val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters) + val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters, params) placeholders += (queryIdPlaceHolder -> queryId) 
queryId } @@ -795,7 +802,8 @@ private[snowpark] class Query( finalQuery, !returnIterator, returnIterator, - conn.getStatementParameters(isDDLOnTempObject, statementParameters)) + conn.getStatementParameters(isDDLOnTempObject, statementParameters), + params) placeholders += (queryIdPlaceHolder -> result.queryId) result } @@ -806,7 +814,7 @@ private[snowpark] class BatchInsertQuery( override val queryIdPlaceHolder: String, attributes: Seq[Attribute], rows: Seq[Row]) - extends Query(sql, queryIdPlaceHolder, false) { + extends Query(sql, queryIdPlaceHolder, false, Seq.empty) { override def runQuery( conn: ServerConnection, placeholders: mutable.HashMap[String, String], @@ -832,11 +840,19 @@ object Query { s"query_id_place_holder_${Random.alphanumeric.take(10).mkString}" def apply(sql: String): Query = { - new Query(sql, placeHolder(), false) + new Query(sql, placeHolder(), false, Seq.empty) + } + + def apply(sql: String, params: Seq[Any]): Query = { + new Query(sql, placeHolder(), false, params) } def apply(sql: String, isDDLOnTempObject: Boolean): Query = { - new Query(sql, placeHolder(), isDDLOnTempObject) + new Query(sql, placeHolder(), isDDLOnTempObject, Seq.empty) + } + + def apply(sql: String, isDDLOnTempObject: Boolean, params: Seq[Any]): Query = { + new Query(sql, placeHolder(), isDDLOnTempObject, params) } def apply(sql: String, attributes: Seq[Attribute], rows: Seq[Row]): Query = { diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java index ee00443c..4d192738 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java @@ -7,6 +7,8 @@ import com.snowflake.snowpark_java.types.StructField; import com.snowflake.snowpark_java.types.StructType; import java.sql.Connection; +import java.util.Arrays; + import net.snowflake.client.jdbc.SnowflakeConnection; import org.junit.Test; @@ -61,6 +63,19 
@@ public void generator() { new Row[] {Row.create(1, 2), Row.create(1, 2), Row.create(1, 2)}); } + @Test + public void sql() { + checkAnswer( + getSession().sql("select * from values(1, 2),(3, 4) as t(a, b)"), + new Row[] {Row.create(1, 2), Row.create(3, 4)} + ); + + checkAnswer( + getSession().sql("select * from values(?, ?),(?, ?) as t(a, b)", 1, 2, 3, 4), + new Row[] {Row.create(1, 2), Row.create(3, 4)} + ); + } + @Test public void getSessionStage() { assert getSession().getSessionStage().contains("SNOWPARK_TEMP_STAGE"); diff --git a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala index 5269f4b2..ae62311e 100644 --- a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala @@ -156,4 +156,22 @@ class ServerConnectionSuite extends SNTestBase { } } + test("ServerConnection with binding parameters") { + val sql = "select * from values (?),(?),(?)" + val params = Seq(1, 2, 3) + + val statement = session.conn.connection.prepareStatement(sql) + params.zipWithIndex.foreach { + case (p, i) => statement.setObject(i + 1, p) + } + + val rs = statement.executeQuery() + assert(rs.eq(statement.getResultSet)) + rs.next() + assert(rs.getInt(1) == 1) + rs.next() + assert(rs.getInt(1) == 2) + rs.next() + assert(rs.getInt(1) == 3) + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala index a0ae3be0..1817a078 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala @@ -104,6 +104,37 @@ trait SqlSuite extends SNTestBase { new Directory(new File(outputPath)).deleteRecursively() } } + + + test("Run sql query with bindings") { + val df1 = session.sql("select * from values (?),(?),(?)", List(1, 2, 3)) + assert(df1.collect() sameElements Array[Row](Row(1), Row(2), 
Row(3))) + + val df2 = + session.sql( + "select variance(identifier(?)) from values(1,1),(1,2),(2,1),(2,2),(3,1),(3,2) as T(a, b)", + Seq("a")) + assert(df2.collect()(0).getDecimal(0).toString == "0.800000") + + val df3 = session + .sql("select * from values (?),(?),(?) as T(id)", Seq(1, 2, 3)) + .filter(col("id") < 3) + assert(df3.collect() sameElements Array[Row](Row(1), Row(2))) + + val df4 = session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", Seq(1, 1, 2, 1, 3, 1)) + val df5 = session.sql( + "select * from values (?,?),(?,?),(?,?) as T(a, b)", + List(1, 2, 2, 1, 4, 3)) + val df6 = df4.union(df5).filter(col("a") < 3) + assert(df6.collect() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) + + val df7 = df4.join(df5, Seq("a", "b"), "inner") + assert(df7.collect() sameElements Array[Row](Row(2, 1))) + + // Async result + assert(df1.async.collect().getResult() sameElements Array[Row](Row(1), Row(2), Row(3))) + assert(df6.async.collect().getResult() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) + } } class EagerSqlSuite extends SqlSuite with EagerSession { @@ -184,6 +215,7 @@ class EagerSqlSuite extends SqlSuite with EagerSession { assertThrows[SnowflakeSQLException](session.sql("SHOW TABLE")) } } + class LazySqlSuite extends SqlSuite with LazySession { test("Run sql query") { val df1 = session.sql("select * from values (1),(2),(3)") From e7d8b99743babff7c5d96bca953522a743382b36 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Wed, 30 Oct 2024 18:34:51 +0000 Subject: [PATCH 2/4] Update linting --- .../java/com/snowflake/snowpark_java/Session.java | 3 +-- .../snowpark/internal/ServerConnection.scala | 8 ++++---- .../snowpark/internal/analyzer/SnowflakePlan.scala | 3 +-- .../snowflake/snowpark_test/JavaSessionSuite.java | 8 ++------ .../scala/com/snowflake/snowpark_test/SqlSuite.scala | 12 ++++++------ 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java 
b/src/main/java/com/snowflake/snowpark_java/Session.java index 443518c6..192d9705 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -79,8 +79,7 @@ public DataFrame sql(String query) { * @since 0.8.0 */ public DataFrame sql(String query, Object... params) { - return new DataFrame( - session.sql(query, JavaUtils.objectArrayToSeq(params))); + return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(params))); } /** diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index e855afb3..5fd3feb2 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -288,8 +288,7 @@ private[snowpark] class ServerConnection( true, false, getStatementParameters(isDDLOnTempObject = false, Map.empty), - Seq.empty - ).rows.get + Seq.empty).rows.get .map(r => r.getString(0).toLowerCase() + PackageNameDelimiter + r.getString(1).toLowerCase()) .toSet @@ -298,7 +297,7 @@ private[snowpark] class ServerConnection( statement: PreparedStatement, params: Seq[Any]): Unit = params.zipWithIndex.foreach { - case (p, i) => statement.setObject(i + 1, p) + case (p, i) => statement.setObject(i + 1, p) } private[snowflake] def setStatementParameters( @@ -883,7 +882,8 @@ private[snowpark] class ServerConnection( val statement = connection.prepareStatement(multipleStatements) try { // Note binding parameters only supported for single query - val bindingParameters = if (plan.queries.length == 1) plan.queries.last.params else Seq() + val bindingParameters = + if (plan.queries.length == 1) plan.queries.last.params else Seq() val statementParameters = getStatementParameters() + ("MULTI_STATEMENT_COUNT" -> plan.queries.size) setBindingParameters(statement, bindingParameters) diff --git 
a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala index 0c12b8f5..0a32e0a4 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala @@ -292,8 +292,7 @@ class SnowflakePlanBuilder(session: Session) extends Logging { selectLeft.queries.slice(0, selectLeft.queries.length - 1) ++ selectRight.queries.slice(0, selectRight.queries.length - 1) :+ Query( sqlGenerator(lastQueryLeft.sql, lastQueryRight.sql), - lastQueryLeft.params ++ lastQueryRight.params - ) + lastQueryLeft.params ++ lastQueryRight.params) val leftSchemaQuery = schemaValueStatement(selectLeft.attributes) val rightSchemaQuery = schemaValueStatement(selectRight.attributes) val schemaQuery = sqlGenerator(leftSchemaQuery, rightSchemaQuery) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java index 4d192738..3fb4a626 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java @@ -7,8 +7,6 @@ import com.snowflake.snowpark_java.types.StructField; import com.snowflake.snowpark_java.types.StructType; import java.sql.Connection; -import java.util.Arrays; - import net.snowflake.client.jdbc.SnowflakeConnection; import org.junit.Test; @@ -67,13 +65,11 @@ public void generator() { public void sql() { checkAnswer( getSession().sql("select * from values(1, 2),(3, 4) as t(a, b)"), - new Row[] {Row.create(1, 2), Row.create(3, 4)} - ); + new Row[] {Row.create(1, 2), Row.create(3, 4)}); checkAnswer( getSession().sql("select * from values(?, ?),(?, ?) 
as t(a, b)", 1, 2, 3, 4), - new Row[] {Row.create(1, 2), Row.create(3, 4)} - ); + new Row[] {Row.create(1, 2), Row.create(3, 4)}); } @Test diff --git a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala index 1817a078..d0478606 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala @@ -105,7 +105,6 @@ trait SqlSuite extends SNTestBase { } } - test("Run sql query with bindings") { val df1 = session.sql("select * from values (?),(?),(?)", List(1, 2, 3)) assert(df1.collect() sameElements Array[Row](Row(1), Row(2), Row(3))) @@ -121,10 +120,10 @@ trait SqlSuite extends SNTestBase { .filter(col("id") < 3) assert(df3.collect() sameElements Array[Row](Row(1), Row(2))) - val df4 = session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", Seq(1, 1, 2, 1, 3, 1)) - val df5 = session.sql( - "select * from values (?,?),(?,?),(?,?) as T(a, b)", - List(1, 2, 2, 1, 4, 3)) + val df4 = + session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", Seq(1, 1, 2, 1, 3, 1)) + val df5 = + session.sql("select * from values (?,?),(?,?),(?,?) 
as T(a, b)", List(1, 2, 2, 1, 4, 3)) val df6 = df4.union(df5).filter(col("a") < 3) assert(df6.collect() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) @@ -133,7 +132,8 @@ trait SqlSuite extends SNTestBase { // Async result assert(df1.async.collect().getResult() sameElements Array[Row](Row(1), Row(2), Row(3))) - assert(df6.async.collect().getResult() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) + assert( + df6.async.collect().getResult() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) } } From 60897151623f60d8d6d7777bb87f4e9b1423a8d1 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Fri, 1 Nov 2024 18:21:11 +0000 Subject: [PATCH 3/4] Update from code review comments --- .../com/snowflake/snowpark_java/Session.java | 2 +- .../snowpark/internal/ErrorMessage.scala | 6 ++++- .../snowpark/internal/ServerConnection.scala | 7 +++--- .../internal/analyzer/SnowflakePlan.scala | 25 ++++++++----------- .../snowpark/ServerConnectionSuite.scala | 22 ++++++++++++++++ 5 files changed, 42 insertions(+), 20 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index 192d9705..f523daae 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -76,7 +76,7 @@ public DataFrame sql(String query) { * @param query The SQL statement to execute. * @param params The binding parameters for SQL statement (optional) * @return A {@code DataFrame} object - * @since 0.8.0 + * @since 1.15.0 */ public DataFrame sql(String query, Object... 
params) { return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(params))); diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index ea14da1e..8cfead31 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -162,7 +162,8 @@ private[snowpark] object ErrorMessage { "0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.", "0426" -> "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.", "0427" -> "The query tag of the current session must be a valid JSON string. Current query tag: %s", - "0428" -> "Failed to serialize the query tag into a JSON string.") + "0428" -> "Failed to serialize the query tag into a JSON string.", + "0429" -> "Binding parameter not supported on multi-statement query.") // scalastyle:on /* @@ -421,6 +422,9 @@ private[snowpark] object ErrorMessage { def MISC_FAILED_TO_SERIALIZE_QUERY_TAG(): SnowparkClientException = createException("0428") + def BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED(): SnowparkClientException = + createException("0429") + /** * Create Snowpark client Exception. 
* diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index 5fd3feb2..3be7e864 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -871,7 +871,6 @@ private[snowpark] class ServerConnection( logDebug(s"""execute plan in async mode: |----------SNOW----------- - | |$plan |------------------------- |""".stripMargin) @@ -882,8 +881,10 @@ private[snowpark] class ServerConnection( val statement = connection.prepareStatement(multipleStatements) try { // Note binding parameters only supported for single query - val bindingParameters = - if (plan.queries.length == 1) plan.queries.last.params else Seq() + val bindingParameters = plan.queries.map(_.params).flatten + if (plan.queries.length > 1 && bindingParameters.length > 0) { + throw ErrorMessage.BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED + } val statementParameters = getStatementParameters() + ("MULTI_STATEMENT_COUNT" -> plan.queries.size) setBindingParameters(statement, bindingParameters) diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala index 0a32e0a4..db8e6c74 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala @@ -292,6 +292,7 @@ class SnowflakePlanBuilder(session: Session) extends Logging { selectLeft.queries.slice(0, selectLeft.queries.length - 1) ++ selectRight.queries.slice(0, selectRight.queries.length - 1) :+ Query( sqlGenerator(lastQueryLeft.sql, lastQueryRight.sql), + false, lastQueryLeft.params ++ lastQueryRight.params) val leftSchemaQuery = schemaValueStatement(selectLeft.attributes) val rightSchemaQuery = schemaValueStatement(selectRight.attributes) @@ -315,7 
+316,10 @@ class SnowflakePlanBuilder(session: Session) extends Logging { val queries: Seq[Query] = selectChildren .map(c => c.queries.slice(0, c.queries.length - 1)) - .reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql)), params) + .reduce(_ ++ _) :+ Query( + sqlGenerator(selectChildren.map(_.queries.last.sql)), + false, + params) val schemaQueries = children.map(c => schemaValueStatement(c.attributes)) val schemaQuery = sqlGenerator(schemaQueries) @@ -329,7 +333,7 @@ class SnowflakePlanBuilder(session: Session) extends Logging { sourcePlan: Option[LogicalPlan], supportAsyncMode: Boolean = true, params: Seq[Any] = Seq.empty): SnowflakePlan = - SnowflakePlan(Seq(Query(sql, params)), sql, session, sourcePlan, supportAsyncMode) + SnowflakePlan(Seq(Query(sql, false, params)), sql, session, sourcePlan, supportAsyncMode) def largeLocalRelationPlan( output: Seq[Attribute], @@ -838,19 +842,10 @@ object Query { private def placeHolder(): String = s"query_id_place_holder_${Random.alphanumeric.take(10).mkString}" - def apply(sql: String): Query = { - new Query(sql, placeHolder(), false, Seq.empty) - } - - def apply(sql: String, params: Seq[Any]): Query = { - new Query(sql, placeHolder(), false, params) - } - - def apply(sql: String, isDDLOnTempObject: Boolean): Query = { - new Query(sql, placeHolder(), isDDLOnTempObject, Seq.empty) - } - - def apply(sql: String, isDDLOnTempObject: Boolean, params: Seq[Any]): Query = { + def apply( + sql: String, + isDDLOnTempObject: Boolean = false, + params: Seq[Any] = Seq.empty): Query = { new Query(sql, placeHolder(), isDDLOnTempObject, params) } diff --git a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala index ae62311e..bf2509bf 100644 --- a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala @@ -173,5 +173,27 @@ class ServerConnectionSuite 
extends SNTestBase { assert(rs.getInt(1) == 2) rs.next() assert(rs.getInt(1) == 3) + + // Test that a multi-statement query with binding parameters is rejected as not supported. + val tableName = randomName() + val queries = Seq( + Query(s"create or replace temporary table $tableName (c1 int, c2 string)"), + Query( + s"insert into $tableName values (?, ?), (?, ?)", + false, + Seq(1, "abc", 123, "dfdffdfdf")), + Query("select SYSTEM$WAIT(?)", false, Seq(2)), + Query(s"select max(c1) from $tableName")) + val plan = new SnowflakePlan( + queries, + schemaValueStatement(Seq(Attribute("C1", LongType))), + Seq(Query(s"drop table if exists $tableName", true)), + session, + None, + supportAsyncMode = true) + + assertThrows[SnowparkClientException] { + session.conn.executeAsync(plan) + } } } From 3a53743e89c93b0fecdce20af70d48bc8ba1fcd7 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Mon, 4 Nov 2024 20:33:58 +0000 Subject: [PATCH 4/4] Add ErrorMessageSuite test --- .../scala/com/snowflake/snowpark/ErrorMessageSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index 0ad6d802..243819b1 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -886,4 +886,11 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string.")) } + + test("BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED") { + val ex = ErrorMessage.BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0429"))) + assert(ex.message.startsWith( + "Error Code: 0429, Error message: Binding parameter not supported on multi-statement query.")) + } }