diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index c2f4ef6d..f523daae 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -65,7 +65,21 @@ public static SessionBuilder builder() { * @since 0.8.0 */ public DataFrame sql(String query) { - return new DataFrame(session.sql(query)); + return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(new Object[0]))); + } + + /** + * Returns a new {@code DataFrame} representing the results of a SQL query. + * + * <p>
You can use this method to execute an arbitrary SQL statement. + * + * @param query The SQL statement to execute. + * @param params Values to bind, in order, to the {@code ?} placeholders in the query (optional) + * @return A {@code DataFrame} object + * @since 1.15.0 + */ + public DataFrame sql(String query, Object... params) { + return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(params))); } /** diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index 1beac1c0..8a670b90 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -945,12 +945,13 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log * You can use this method to execute an arbitrary SQL statement. * * @param query The SQL statement to execute. + * @param params Values to bind, in order, to the `?` placeholders in the query. * @return A [[DataFrame]] * @since 0.1.0 */ - def sql(query: String): DataFrame = { + def sql(query: String, params: Seq[Any] = Seq.empty): DataFrame = { // PUT and GET command cannot be executed in async mode - DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query))) + DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query), params)) } /** diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index ea14da1e..8cfead31 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -162,7 +162,8 @@ private[snowpark] object ErrorMessage { "0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.", "0426" -> "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.", "0427" -> "The query tag of the current session must be a valid JSON string. 
Current query tag: %s", - "0428" -> "Failed to serialize the query tag into a JSON string.") + "0428" -> "Failed to serialize the query tag into a JSON string.", + "0429" -> "Binding parameter not supported on multi-statement query.") // scalastyle:on /* @@ -421,6 +422,9 @@ private[snowpark] object ErrorMessage { def MISC_FAILED_TO_SERIALIZE_QUERY_TAG(): SnowparkClientException = createException("0428") + def BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED(): SnowparkClientException = + createException("0429") + /** * Create Snowpark client Exception. * diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index a2281925..3be7e864 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -30,6 +30,7 @@ import net.snowflake.client.jdbc.{ SnowflakeBaseResultSet, SnowflakeConnectString, SnowflakeConnectionV1, + SnowflakePreparedStatement, SnowflakeReauthenticationRequest, SnowflakeResultSet, SnowflakeResultSetMetaData, @@ -286,11 +287,19 @@ private[snowpark] class ServerConnection( s"where language = 'java'", true, false, - getStatementParameters(isDDLOnTempObject = false, Map.empty)).rows.get + getStatementParameters(isDDLOnTempObject = false, Map.empty), + Seq.empty).rows.get .map(r => r.getString(0).toLowerCase() + PackageNameDelimiter + r.getString(1).toLowerCase()) .toSet + private[snowflake] def setBindingParameters( + statement: PreparedStatement, + params: Seq[Any]): Unit = + params.zipWithIndex.foreach { + case (p, i) => statement.setObject(i + 1, p) + } + private[snowflake] def setStatementParameters( statement: Statement, parameters: Map[String, Any]): Unit = @@ -438,12 +447,14 @@ private[snowpark] class ServerConnection( def runQuery( query: String, isDDLOnTempObject: Boolean = false, - statementParameters: Map[String, Any] = Map.empty): String = + 
statementParameters: Map[String, Any] = Map.empty, + params: Seq[Any] = Seq.empty): String = runQueryGetResult( query, returnRows = false, returnIterator = false, - getStatementParameters(isDDLOnTempObject, statementParameters)).queryId + getStatementParameters(isDDLOnTempObject, statementParameters), + params).queryId // Run the query and return the queryID when the caller doesn't need the ResultSet def runQueryGetRows( @@ -453,7 +464,8 @@ private[snowpark] class ServerConnection( query, returnRows = true, returnIterator = false, - getStatementParameters(isDDLOnTempObject = false, statementParameters)).rows.get + getStatementParameters(isDDLOnTempObject = false, statementParameters), + Seq.empty).rows.get // Run the query to get query result. // 1. If the caller needs to get Iterator[Row], the internal JDBC ResultSet and Statement @@ -466,11 +478,13 @@ private[snowpark] class ServerConnection( query: String, returnRows: Boolean, returnIterator: Boolean, - statementParameters: Map[String, Any]): QueryResult = + statementParameters: Map[String, Any], + params: Seq[Any]): QueryResult = withValidConnection { var statement: PreparedStatement = null try { statement = connection.prepareStatement(query) + setBindingParameters(statement, params) setStatementParameters(statement, statementParameters) val rs = statement.executeQuery() val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID @@ -862,15 +876,20 @@ private[snowpark] class ServerConnection( |""".stripMargin) // use try finally to ensure postActions is always run - val statement = connection.createStatement() + val queries = plan.queries.map(_.sql) + val multipleStatements = queries.mkString("; ") + val statement = connection.prepareStatement(multipleStatements) try { - val queries = plan.queries.map(_.sql) - val multipleStatements = queries.mkString("; ") + // Note: binding parameters are only supported for a single query + val bindingParameters = plan.queries.map(_.params).flatten + if (plan.queries.length > 1 && 
bindingParameters.length > 0) { + throw ErrorMessage.BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED + } val statementParameters = getStatementParameters() + ("MULTI_STATEMENT_COUNT" -> plan.queries.size) + setBindingParameters(statement, bindingParameters) setStatementParameters(statement, statementParameters) - val rs = - statement.asInstanceOf[SnowflakeStatement].executeAsyncQuery(multipleStatements) + val rs = statement.asInstanceOf[SnowflakePreparedStatement].executeAsyncQuery() val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID if (actionID <= plan.session.getLastCanceledID) { throw ErrorMessage.MISC_QUERY_IS_CANCELLED() diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala index a3218758..db8e6c74 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala @@ -266,8 +266,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging { schemaQuery: Option[String], isDDLOnTempObject: Boolean): SnowflakePlan = wrapException(child) { val selectChild = addResultScanIfNotSelect(child) + val lastQuery = selectChild.queries.last val queries: Seq[Query] = selectChild.queries.slice(0, selectChild.queries.length - 1) ++ - multipleSqlGenerator(selectChild.queries.last.sql).map(Query(_, isDDLOnTempObject)) + multipleSqlGenerator(lastQuery.sql).map(Query(_, isDDLOnTempObject, lastQuery.params)) val newSchemaQuery = schemaQuery.getOrElse(multipleSqlGenerator(child.schemaQuery).last) SnowflakePlan( queries, @@ -284,15 +285,18 @@ class SnowflakePlanBuilder(session: Session) extends Logging { right: SnowflakePlan, sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(left, right) { val selectLeft = addResultScanIfNotSelect(left) + val lastQueryLeft = selectLeft.queries.last val selectRight = addResultScanIfNotSelect(right) + val 
lastQueryRight = selectRight.queries.last val queries: Seq[Query] = selectLeft.queries.slice(0, selectLeft.queries.length - 1) ++ selectRight.queries.slice(0, selectRight.queries.length - 1) :+ Query( - sqlGenerator(selectLeft.queries.last.sql, selectRight.queries.last.sql)) + sqlGenerator(lastQueryLeft.sql, lastQueryRight.sql), + false, + lastQueryLeft.params ++ lastQueryRight.params) val leftSchemaQuery = schemaValueStatement(selectLeft.attributes) val rightSchemaQuery = schemaValueStatement(selectRight.attributes) val schemaQuery = sqlGenerator(leftSchemaQuery, rightSchemaQuery) - val commonColumn = selectLeft.aliasMap.keySet.intersect(selectRight.aliasMap.keySet) val supportAsyncMode = selectLeft.supportAsyncMode && selectRight.supportAsyncMode SnowflakePlan( queries, @@ -308,10 +312,14 @@ class SnowflakePlanBuilder(session: Session) extends Logging { children: Seq[SnowflakePlan], sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(children: _*) { val selectChildren = children.map(addResultScanIfNotSelect) + val params: Seq[Any] = selectChildren.map(_.queries.last.params).flatten val queries: Seq[Query] = selectChildren .map(c => c.queries.slice(0, c.queries.length - 1)) - .reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql))) + .reduce(_ ++ _) :+ Query( + sqlGenerator(selectChildren.map(_.queries.last.sql)), + false, + params) val schemaQueries = children.map(c => schemaValueStatement(c.attributes)) val schemaQuery = sqlGenerator(schemaQueries) @@ -323,8 +331,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging { def query( sql: String, sourcePlan: Option[LogicalPlan], - supportAsyncMode: Boolean = true): SnowflakePlan = - SnowflakePlan(Seq(Query(sql)), sql, session, sourcePlan, supportAsyncMode) + supportAsyncMode: Boolean = true, + params: Seq[Any] = Seq.empty): SnowflakePlan = + SnowflakePlan(Seq(Query(sql, false, params)), sql, session, sourcePlan, supportAsyncMode) def largeLocalRelationPlan( output: 
Seq[Attribute], @@ -764,7 +773,8 @@ class SnowflakePlanBuilder(session: Session) extends Logging { private[snowpark] class Query( val sql: String, val queryIdPlaceHolder: String, - val isDDLOnTempObject: Boolean) + val isDDLOnTempObject: Boolean, + val params: Seq[Any]) extends Logging { logDebug(s"Creating a new Query: $sql ID: $queryIdPlaceHolder") override def toString: String = sql @@ -776,7 +786,7 @@ private[snowpark] class Query( placeholders.foreach { case (holder, id) => finalQuery = finalQuery.replaceAll(holder, id) } - val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters) + val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters, params) placeholders += (queryIdPlaceHolder -> queryId) queryId } @@ -795,7 +805,8 @@ private[snowpark] class Query( finalQuery, !returnIterator, returnIterator, - conn.getStatementParameters(isDDLOnTempObject, statementParameters)) + conn.getStatementParameters(isDDLOnTempObject, statementParameters), + params) placeholders += (queryIdPlaceHolder -> result.queryId) result } @@ -806,7 +817,7 @@ private[snowpark] class BatchInsertQuery( override val queryIdPlaceHolder: String, attributes: Seq[Attribute], rows: Seq[Row]) - extends Query(sql, queryIdPlaceHolder, false) { + extends Query(sql, queryIdPlaceHolder, false, Seq.empty) { override def runQuery( conn: ServerConnection, placeholders: mutable.HashMap[String, String], @@ -831,12 +842,11 @@ object Query { private def placeHolder(): String = s"query_id_place_holder_${Random.alphanumeric.take(10).mkString}" - def apply(sql: String): Query = { - new Query(sql, placeHolder(), false) - } - - def apply(sql: String, isDDLOnTempObject: Boolean): Query = { - new Query(sql, placeHolder(), isDDLOnTempObject) + def apply( + sql: String, + isDDLOnTempObject: Boolean = false, + params: Seq[Any] = Seq.empty): Query = { + new Query(sql, placeHolder(), isDDLOnTempObject, params) } def apply(sql: String, attributes: Seq[Attribute], rows: Seq[Row]): 
Query = { diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java index ee00443c..3fb4a626 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java @@ -61,6 +61,17 @@ public void generator() { new Row[] {Row.create(1, 2), Row.create(1, 2), Row.create(1, 2)}); } + @Test + public void sql() { + checkAnswer( + getSession().sql("select * from values(1, 2),(3, 4) as t(a, b)"), + new Row[] {Row.create(1, 2), Row.create(3, 4)}); + + checkAnswer( + getSession().sql("select * from values(?, ?),(?, ?) as t(a, b)", 1, 2, 3, 4), + new Row[] {Row.create(1, 2), Row.create(3, 4)}); + } + @Test public void getSessionStage() { assert getSession().getSessionStage().contains("SNOWPARK_TEMP_STAGE"); diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index 0ad6d802..243819b1 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -886,4 +886,11 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string.")) } + + test("BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED") { + val ex = ErrorMessage.BINDING_PARAMETER_MULTI_STATEMENT_NOT_SUPPORTED() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0429"))) + assert(ex.message.startsWith( + "Error Code: 0429, Error message: Binding parameter not supported on multi-statement query.")) + } } diff --git a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala index 5269f4b2..bf2509bf 100644 --- a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala +++ 
b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala @@ -156,4 +156,44 @@ class ServerConnectionSuite extends SNTestBase { } } + test("ServerConnection with binding parameters") { + val sql = "select * from values (?),(?),(?)" + val params = Seq(1, 2, 3) + + val statement = session.conn.connection.prepareStatement(sql) + params.zipWithIndex.foreach { + case (p, i) => statement.setObject(i + 1, p) + } + + val rs = statement.executeQuery() + assert(rs.eq(statement.getResultSet)) + rs.next() + assert(rs.getInt(1) == 1) + rs.next() + assert(rs.getInt(1) == 2) + rs.next() + assert(rs.getInt(1) == 3) + + // Binding parameters are not supported in multi-statement queries; executeAsync must throw. + val tableName = randomName() + val queries = Seq( + Query(s"create or replace temporary table $tableName (c1 int, c2 string)"), + Query( + s"insert into $tableName values (?, ?), (?, ?)", + false, + Seq(1, "abc", 123, "dfdffdfdf")), + Query("select SYSTEM$WAIT(?)", false, Seq(2)), + Query(s"select max(c1) from $tableName")) + val plan = new SnowflakePlan( + queries, + schemaValueStatement(Seq(Attribute("C1", LongType))), + Seq(Query(s"drop table if exists $tableName", true)), + session, + None, + supportAsyncMode = true) + + assertThrows[SnowparkClientException] { + session.conn.executeAsync(plan) + } + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala index a0ae3be0..d0478606 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala @@ -104,6 +104,37 @@ trait SqlSuite extends SNTestBase { new Directory(new File(outputPath)).deleteRecursively() } } + + test("Run sql query with bindings") { + val df1 = session.sql("select * from values (?),(?),(?)", List(1, 2, 3)) + assert(df1.collect() sameElements Array[Row](Row(1), Row(2), Row(3))) + + val df2 = + session.sql( + "select variance(identifier(?)) from 
values(1,1),(1,2),(2,1),(2,2),(3,1),(3,2) as T(a, b)", + Seq("a")) + assert(df2.collect()(0).getDecimal(0).toString == "0.800000") + + val df3 = session + .sql("select * from values (?),(?),(?) as T(id)", Seq(1, 2, 3)) + .filter(col("id") < 3) + assert(df3.collect() sameElements Array[Row](Row(1), Row(2))) + + val df4 = + session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", Seq(1, 1, 2, 1, 3, 1)) + val df5 = + session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", List(1, 2, 2, 1, 4, 3)) + val df6 = df4.union(df5).filter(col("a") < 3) + assert(df6.collect() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) + + val df7 = df4.join(df5, Seq("a", "b"), "inner") + assert(df7.collect() sameElements Array[Row](Row(2, 1))) + + // Async result + assert(df1.async.collect().getResult() sameElements Array[Row](Row(1), Row(2), Row(3))) + assert( + df6.async.collect().getResult() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))) + } } class EagerSqlSuite extends SqlSuite with EagerSession { @@ -184,6 +215,7 @@ class EagerSqlSuite extends SqlSuite with EagerSession { assertThrows[SnowflakeSQLException](session.sql("SHOW TABLE")) } } + class LazySqlSuite extends SqlSuite with LazySession { test("Run sql query") { val df1 = session.sql("select * from values (1),(2),(3)")