Skip to content

Commit

Permalink
SNOW-1022196 Support binding parameters for snowpark java api
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-evandenberg committed Oct 30, 2024
1 parent e57e78a commit e565db0
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 25 deletions.
17 changes: 16 additions & 1 deletion src/main/java/com/snowflake/snowpark_java/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,22 @@ public static SessionBuilder builder() {
* @since 0.8.0
*/
public DataFrame sql(String query) {
return new DataFrame(session.sql(query));
return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(new Object[0])));
}

/**
* Returns a new {@code DataFrame} representing the results of a SQL query.
*
* <p>You can use this method to execute an arbitrary SQL statement.
*
* @param query The SQL statement to execute.
* @param params The values to bind, in order, to the {@code ?} placeholders in the SQL statement (optional)
* @return A {@code DataFrame} object
* @since 0.8.0
*/
public DataFrame sql(String query, Object... params) {
// The underlying session's sql() takes the bind values as a Seq, so the Java varargs
// array is passed through JavaUtils.objectArrayToSeq before delegating.
return new DataFrame(
session.sql(query, JavaUtils.objectArrayToSeq(params)));
}

/**
Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -945,12 +945,13 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* You can use this method to execute an arbitrary SQL statement.
*
* @param query The SQL statement to execute.
* @param params The values to bind, in order, to the `?` placeholders in the SQL statement.
* @return A [[DataFrame]]
* @since 0.1.0
*/
def sql(query: String): DataFrame = {
def sql(query: String, params: Seq[Any] = Seq.empty): DataFrame = {
// PUT and GET command cannot be executed in async mode
DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query)))
DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query), params))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import net.snowflake.client.jdbc.{
SnowflakeBaseResultSet,
SnowflakeConnectString,
SnowflakeConnectionV1,
SnowflakePreparedStatement,
SnowflakeReauthenticationRequest,
SnowflakeResultSet,
SnowflakeResultSetMetaData,
Expand Down Expand Up @@ -286,11 +287,20 @@ private[snowpark] class ServerConnection(
s"where language = 'java'",
true,
false,
getStatementParameters(isDDLOnTempObject = false, Map.empty)).rows.get
getStatementParameters(isDDLOnTempObject = false, Map.empty),
Seq.empty
).rows.get
.map(r =>
r.getString(0).toLowerCase() + PackageNameDelimiter + r.getString(1).toLowerCase())
.toSet

// Binds every value in `params` to the positional placeholders of `statement`.
// JDBC parameter indices start at 1, so the zero-based position is shifted by one.
private[snowflake] def setBindingParameters(
    statement: PreparedStatement,
    params: Seq[Any]): Unit =
  for ((value, position) <- params.zipWithIndex) {
    statement.setObject(position + 1, value)
  }

private[snowflake] def setStatementParameters(
statement: Statement,
parameters: Map[String, Any]): Unit =
Expand Down Expand Up @@ -438,12 +448,14 @@ private[snowpark] class ServerConnection(
def runQuery(
query: String,
isDDLOnTempObject: Boolean = false,
statementParameters: Map[String, Any] = Map.empty): String =
statementParameters: Map[String, Any] = Map.empty,
params: Seq[Any] = Seq.empty): String =
runQueryGetResult(
query,
returnRows = false,
returnIterator = false,
getStatementParameters(isDDLOnTempObject, statementParameters)).queryId
getStatementParameters(isDDLOnTempObject, statementParameters),
params).queryId

// Run the query and return the result rows when the caller doesn't need an Iterator
def runQueryGetRows(
Expand All @@ -453,7 +465,8 @@ private[snowpark] class ServerConnection(
query,
returnRows = true,
returnIterator = false,
getStatementParameters(isDDLOnTempObject = false, statementParameters)).rows.get
getStatementParameters(isDDLOnTempObject = false, statementParameters),
Seq.empty).rows.get

// Run the query to get query result.
// 1. If the caller needs to get Iterator[Row], the internal JDBC ResultSet and Statement
Expand All @@ -466,11 +479,13 @@ private[snowpark] class ServerConnection(
query: String,
returnRows: Boolean,
returnIterator: Boolean,
statementParameters: Map[String, Any]): QueryResult =
statementParameters: Map[String, Any],
params: Seq[Any]): QueryResult =
withValidConnection {
var statement: PreparedStatement = null
try {
statement = connection.prepareStatement(query)
setBindingParameters(statement, params)
setStatementParameters(statement, statementParameters)
val rs = statement.executeQuery()
val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID
Expand Down Expand Up @@ -857,20 +872,23 @@ private[snowpark] class ServerConnection(

logDebug(s"""execute plan in async mode:
|----------SNOW-----------
|
|$plan
|-------------------------
|""".stripMargin)

// use try finally to ensure postActions is always run
val statement = connection.createStatement()
val queries = plan.queries.map(_.sql)
val multipleStatements = queries.mkString("; ")
val statement = connection.prepareStatement(multipleStatements)
try {
val queries = plan.queries.map(_.sql)
val multipleStatements = queries.mkString("; ")
// Note: binding parameters are only supported when the plan consists of a single query
val bindingParameters = if (plan.queries.length == 1) plan.queries.last.params else Seq()
val statementParameters = getStatementParameters() +
("MULTI_STATEMENT_COUNT" -> plan.queries.size)
setBindingParameters(statement, bindingParameters)
setStatementParameters(statement, statementParameters)
val rs =
statement.asInstanceOf[SnowflakeStatement].executeAsyncQuery(multipleStatements)
val rs = statement.asInstanceOf[SnowflakePreparedStatement].executeAsyncQuery()
val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID
if (actionID <= plan.session.getLastCanceledID) {
throw ErrorMessage.MISC_QUERY_IS_CANCELLED()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
schemaQuery: Option[String],
isDDLOnTempObject: Boolean): SnowflakePlan = wrapException(child) {
val selectChild = addResultScanIfNotSelect(child)
val lastQuery = selectChild.queries.last
val queries: Seq[Query] = selectChild.queries.slice(0, selectChild.queries.length - 1) ++
multipleSqlGenerator(selectChild.queries.last.sql).map(Query(_, isDDLOnTempObject))
multipleSqlGenerator(lastQuery.sql).map(Query(_, isDDLOnTempObject, lastQuery.params))
val newSchemaQuery = schemaQuery.getOrElse(multipleSqlGenerator(child.schemaQuery).last)
SnowflakePlan(
queries,
Expand All @@ -284,15 +285,18 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
right: SnowflakePlan,
sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(left, right) {
val selectLeft = addResultScanIfNotSelect(left)
val lastQueryLeft = selectLeft.queries.last
val selectRight = addResultScanIfNotSelect(right)
val lastQueryRight = selectRight.queries.last
val queries: Seq[Query] =
selectLeft.queries.slice(0, selectLeft.queries.length - 1) ++
selectRight.queries.slice(0, selectRight.queries.length - 1) :+ Query(
sqlGenerator(selectLeft.queries.last.sql, selectRight.queries.last.sql))
sqlGenerator(lastQueryLeft.sql, lastQueryRight.sql),
lastQueryLeft.params ++ lastQueryRight.params
)
val leftSchemaQuery = schemaValueStatement(selectLeft.attributes)
val rightSchemaQuery = schemaValueStatement(selectRight.attributes)
val schemaQuery = sqlGenerator(leftSchemaQuery, rightSchemaQuery)
val commonColumn = selectLeft.aliasMap.keySet.intersect(selectRight.aliasMap.keySet)
val supportAsyncMode = selectLeft.supportAsyncMode && selectRight.supportAsyncMode
SnowflakePlan(
queries,
Expand All @@ -308,10 +312,11 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
children: Seq[SnowflakePlan],
sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(children: _*) {
val selectChildren = children.map(addResultScanIfNotSelect)
val params: Seq[Any] = selectChildren.map(_.queries.last.params).flatten
val queries: Seq[Query] =
selectChildren
.map(c => c.queries.slice(0, c.queries.length - 1))
.reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql)))
.reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql)), params)

val schemaQueries = children.map(c => schemaValueStatement(c.attributes))
val schemaQuery = sqlGenerator(schemaQueries)
Expand All @@ -323,8 +328,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
def query(
sql: String,
sourcePlan: Option[LogicalPlan],
supportAsyncMode: Boolean = true): SnowflakePlan =
SnowflakePlan(Seq(Query(sql)), sql, session, sourcePlan, supportAsyncMode)
supportAsyncMode: Boolean = true,
params: Seq[Any] = Seq.empty): SnowflakePlan =
SnowflakePlan(Seq(Query(sql, params)), sql, session, sourcePlan, supportAsyncMode)

def largeLocalRelationPlan(
output: Seq[Attribute],
Expand Down Expand Up @@ -764,7 +770,8 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
private[snowpark] class Query(
val sql: String,
val queryIdPlaceHolder: String,
val isDDLOnTempObject: Boolean)
val isDDLOnTempObject: Boolean,
val params: Seq[Any])
extends Logging {
logDebug(s"Creating a new Query: $sql ID: $queryIdPlaceHolder")
override def toString: String = sql
Expand All @@ -776,7 +783,7 @@ private[snowpark] class Query(
placeholders.foreach {
case (holder, id) => finalQuery = finalQuery.replaceAll(holder, id)
}
val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters)
val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters, params)
placeholders += (queryIdPlaceHolder -> queryId)
queryId
}
Expand All @@ -795,7 +802,8 @@ private[snowpark] class Query(
finalQuery,
!returnIterator,
returnIterator,
conn.getStatementParameters(isDDLOnTempObject, statementParameters))
conn.getStatementParameters(isDDLOnTempObject, statementParameters),
params)
placeholders += (queryIdPlaceHolder -> result.queryId)
result
}
Expand All @@ -806,7 +814,7 @@ private[snowpark] class BatchInsertQuery(
override val queryIdPlaceHolder: String,
attributes: Seq[Attribute],
rows: Seq[Row])
extends Query(sql, queryIdPlaceHolder, false) {
extends Query(sql, queryIdPlaceHolder, false, Seq.empty) {
override def runQuery(
conn: ServerConnection,
placeholders: mutable.HashMap[String, String],
Expand All @@ -832,11 +840,19 @@ object Query {
s"query_id_place_holder_${Random.alphanumeric.take(10).mkString}"

def apply(sql: String): Query = {
new Query(sql, placeHolder(), false)
new Query(sql, placeHolder(), false, Seq.empty)
}

// Builds a Query for `sql` carrying the bind values `params`, with a placeholder
// obtained from placeHolder() and isDDLOnTempObject set to false.
def apply(sql: String, params: Seq[Any]): Query =
  new Query(sql, placeHolder(), isDDLOnTempObject = false, params = params)

def apply(sql: String, isDDLOnTempObject: Boolean): Query = {
new Query(sql, placeHolder(), isDDLOnTempObject)
new Query(sql, placeHolder(), isDDLOnTempObject, Seq.empty)
}

// Builds a Query for `sql` with the given isDDLOnTempObject flag and bind values,
// using a placeholder obtained from placeHolder().
def apply(sql: String, isDDLOnTempObject: Boolean, params: Seq[Any]): Query =
  new Query(sql, placeHolder(), isDDLOnTempObject, params)

def apply(sql: String, attributes: Seq[Attribute], rows: Seq[Row]): Query = {
Expand Down
15 changes: 15 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import com.snowflake.snowpark_java.types.StructField;
import com.snowflake.snowpark_java.types.StructType;
import java.sql.Connection;
import java.util.Arrays;

import net.snowflake.client.jdbc.SnowflakeConnection;
import org.junit.Test;

Expand Down Expand Up @@ -61,6 +63,19 @@ public void generator() {
new Row[] {Row.create(1, 2), Row.create(1, 2), Row.create(1, 2)});
}

@Test
public void sql() {
  // Both statements describe the same two rows: once as literals, once as bind values.
  Row[] expected = new Row[] {Row.create(1, 2), Row.create(3, 4)};

  // Values written directly into the SQL text.
  checkAnswer(getSession().sql("select * from values(1, 2),(3, 4) as t(a, b)"), expected);

  // Same query, with the values supplied positionally for the ? placeholders.
  checkAnswer(
      getSession().sql("select * from values(?, ?),(?, ?) as t(a, b)", 1, 2, 3, 4), expected);
}

@Test
public void getSessionStage() {
assert getSession().getSessionStage().contains("SNOWPARK_TEMP_STAGE");
Expand Down
18 changes: 18 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,22 @@ class ServerConnectionSuite extends SNTestBase {
}
}

// Checks that positional bind values set on a JDBC PreparedStatement obtained from the
// session's connection come back as the rows of the query, in the order they were bound.
test("ServerConnection with binding parameters") {
  val sql = "select * from values (?),(?),(?)"
  val params = Seq(1, 2, 3)

  val statement = session.conn.connection.prepareStatement(sql)
  // Close the statement even when an assertion fails, so the test does not leak it.
  try {
    // JDBC parameter indices are 1-based.
    params.zipWithIndex.foreach {
      case (p, i) => statement.setObject(i + 1, p)
    }

    val rs = statement.executeQuery()
    assert(rs.eq(statement.getResultSet))
    // Each bound value must be present as its own row; a missing row fails on next().
    params.foreach { expected =>
      assert(rs.next())
      assert(rs.getInt(1) == expected)
    }
    // Exactly three rows were bound, so the result set must now be exhausted.
    assert(!rs.next())
  } finally {
    statement.close()
  }
}
}
32 changes: 32 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,37 @@ trait SqlSuite extends SNTestBase {
new Directory(new File(outputPath)).deleteRecursively()
}
}


// Exercises session.sql with bind values: plain queries, a bound identifier, bindings
// combined with filter/union/join, and the same DataFrames collected asynchronously.
test("Run sql query with bindings") {
  // One bound value per `?` placeholder, supplied positionally.
  val singleColumn = session.sql("select * from values (?),(?),(?)", List(1, 2, 3))
  val singleColumnRows = Array[Row](Row(1), Row(2), Row(3))
  assert(singleColumn.collect() sameElements singleColumnRows)

  // The bound value is a column name passed through identifier(?).
  val variance = session.sql(
    "select variance(identifier(?)) from values(1,1),(1,2),(2,1),(2,2),(3,1),(3,2) as T(a, b)",
    Seq("a"))
  val varianceRows = variance.collect()
  assert(varianceRows(0).getDecimal(0).toString == "0.800000")

  // Bindings combined with a transformation applied on top of the sql() DataFrame.
  val filtered = session
    .sql("select * from values (?),(?),(?) as T(id)", Seq(1, 2, 3))
    .filter(col("id") < 3)
  assert(filtered.collect() sameElements Array[Row](Row(1), Row(2)))

  // Two DataFrames built from the same statement text with different bound values.
  val pairsSql = "select * from values (?,?),(?,?),(?,?) as T(a, b)"
  val leftPairs = session.sql(pairsSql, Seq(1, 1, 2, 1, 3, 1))
  val rightPairs = session.sql(pairsSql, List(1, 2, 2, 1, 4, 3))

  val unionRows = Array[Row](Row(1, 1), Row(2, 1), Row(1, 2))
  val unioned = leftPairs.union(rightPairs).filter(col("a") < 3)
  assert(unioned.collect() sameElements unionRows)

  val joined = leftPairs.join(rightPairs, Seq("a", "b"), "inner")
  assert(joined.collect() sameElements Array[Row](Row(2, 1)))

  // Async result
  assert(singleColumn.async.collect().getResult() sameElements singleColumnRows)
  assert(unioned.async.collect().getResult() sameElements unionRows)
}
}

class EagerSqlSuite extends SqlSuite with EagerSession {
Expand Down Expand Up @@ -184,6 +215,7 @@ class EagerSqlSuite extends SqlSuite with EagerSession {
assertThrows[SnowflakeSQLException](session.sql("SHOW TABLE"))
}
}

class LazySqlSuite extends SqlSuite with LazySession {
test("Run sql query") {
val df1 = session.sql("select * from values (1),(2),(3)")
Expand Down

0 comments on commit e565db0

Please sign in to comment.