Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1022196 Support binding parameters for snowpark java api #171

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/main/java/com/snowflake/snowpark_java/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,21 @@ public static SessionBuilder builder() {
* @since 0.8.0
*/
public DataFrame sql(String query) {
return new DataFrame(session.sql(query));
return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(new Object[0])));
}

/**
* Returns a new {@code DataFrame} representing the results of a SQL query.
*
* <p>You can use this method to execute an arbitrary SQL statement.
*
* @param query The SQL statement to execute.
* @param params The binding parameters for SQL statement (optional)
* @return A {@code DataFrame} object
* @since 0.8.0
sfc-gh-evandenberg marked this conversation as resolved.
Show resolved Hide resolved
*/
public DataFrame sql(String query, Object... params) {
return new DataFrame(session.sql(query, JavaUtils.objectArrayToSeq(params)));
}

/**
Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -945,12 +945,13 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* You can use this method to execute an arbitrary SQL statement.
*
* @param query The SQL statement to execute.
* @param params for bind variables in SQL statement.
* @return A [[DataFrame]]
* @since 0.1.0
*/
def sql(query: String): DataFrame = {
def sql(query: String, params: Seq[Any] = Seq.empty): DataFrame = {
// PUT and GET command cannot be executed in async mode
DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query)))
DataFrame(this, plans.query(query, None, !Utils.isPutOrGetCommand(query), params))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import net.snowflake.client.jdbc.{
SnowflakeBaseResultSet,
SnowflakeConnectString,
SnowflakeConnectionV1,
SnowflakePreparedStatement,
SnowflakeReauthenticationRequest,
SnowflakeResultSet,
SnowflakeResultSetMetaData,
Expand Down Expand Up @@ -286,11 +287,19 @@ private[snowpark] class ServerConnection(
s"where language = 'java'",
true,
false,
getStatementParameters(isDDLOnTempObject = false, Map.empty)).rows.get
getStatementParameters(isDDLOnTempObject = false, Map.empty),
Seq.empty).rows.get
.map(r =>
r.getString(0).toLowerCase() + PackageNameDelimiter + r.getString(1).toLowerCase())
.toSet

private[snowflake] def setBindingParameters(
statement: PreparedStatement,
params: Seq[Any]): Unit =
params.zipWithIndex.foreach {
case (p, i) => statement.setObject(i + 1, p)
}

private[snowflake] def setStatementParameters(
statement: Statement,
parameters: Map[String, Any]): Unit =
Expand Down Expand Up @@ -438,12 +447,14 @@ private[snowpark] class ServerConnection(
def runQuery(
query: String,
isDDLOnTempObject: Boolean = false,
statementParameters: Map[String, Any] = Map.empty): String =
statementParameters: Map[String, Any] = Map.empty,
params: Seq[Any] = Seq.empty): String =
runQueryGetResult(
query,
returnRows = false,
returnIterator = false,
getStatementParameters(isDDLOnTempObject, statementParameters)).queryId
getStatementParameters(isDDLOnTempObject, statementParameters),
params).queryId

// Run the query and return the queryID when the caller doesn't need the ResultSet
def runQueryGetRows(
Expand All @@ -453,7 +464,8 @@ private[snowpark] class ServerConnection(
query,
returnRows = true,
returnIterator = false,
getStatementParameters(isDDLOnTempObject = false, statementParameters)).rows.get
getStatementParameters(isDDLOnTempObject = false, statementParameters),
Seq.empty).rows.get

// Run the query to get query result.
// 1. If the caller needs to get Iterator[Row], the internal JDBC ResultSet and Statement
Expand All @@ -466,11 +478,13 @@ private[snowpark] class ServerConnection(
query: String,
returnRows: Boolean,
returnIterator: Boolean,
statementParameters: Map[String, Any]): QueryResult =
statementParameters: Map[String, Any],
params: Seq[Any]): QueryResult =
withValidConnection {
var statement: PreparedStatement = null
try {
statement = connection.prepareStatement(query)
setBindingParameters(statement, params)
setStatementParameters(statement, statementParameters)
val rs = statement.executeQuery()
val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID
Expand Down Expand Up @@ -857,20 +871,24 @@ private[snowpark] class ServerConnection(

logDebug(s"""execute plan in async mode:
|----------SNOW-----------
|
sfc-gh-evandenberg marked this conversation as resolved.
Show resolved Hide resolved
|$plan
|-------------------------
|""".stripMargin)

// use try finally to ensure postActions is always run
val statement = connection.createStatement()
val queries = plan.queries.map(_.sql)
val multipleStatements = queries.mkString("; ")
val statement = connection.prepareStatement(multipleStatements)
try {
val queries = plan.queries.map(_.sql)
val multipleStatements = queries.mkString("; ")
// Note binding parameters only supported for single query
val bindingParameters =
if (plan.queries.length == 1) plan.queries.last.params else Seq()
sfc-gh-evandenberg marked this conversation as resolved.
Show resolved Hide resolved
val statementParameters = getStatementParameters() +
("MULTI_STATEMENT_COUNT" -> plan.queries.size)
setBindingParameters(statement, bindingParameters)
setStatementParameters(statement, statementParameters)
val rs =
statement.asInstanceOf[SnowflakeStatement].executeAsyncQuery(multipleStatements)
val rs = statement.asInstanceOf[SnowflakePreparedStatement].executeAsyncQuery()
val queryID = rs.asInstanceOf[SnowflakeResultSet].getQueryID
if (actionID <= plan.session.getLastCanceledID) {
throw ErrorMessage.MISC_QUERY_IS_CANCELLED()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
schemaQuery: Option[String],
isDDLOnTempObject: Boolean): SnowflakePlan = wrapException(child) {
val selectChild = addResultScanIfNotSelect(child)
val lastQuery = selectChild.queries.last
val queries: Seq[Query] = selectChild.queries.slice(0, selectChild.queries.length - 1) ++
multipleSqlGenerator(selectChild.queries.last.sql).map(Query(_, isDDLOnTempObject))
multipleSqlGenerator(lastQuery.sql).map(Query(_, isDDLOnTempObject, lastQuery.params))
val newSchemaQuery = schemaQuery.getOrElse(multipleSqlGenerator(child.schemaQuery).last)
SnowflakePlan(
queries,
Expand All @@ -284,15 +285,17 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
right: SnowflakePlan,
sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(left, right) {
val selectLeft = addResultScanIfNotSelect(left)
val lastQueryLeft = selectLeft.queries.last
val selectRight = addResultScanIfNotSelect(right)
val lastQueryRight = selectRight.queries.last
val queries: Seq[Query] =
selectLeft.queries.slice(0, selectLeft.queries.length - 1) ++
selectRight.queries.slice(0, selectRight.queries.length - 1) :+ Query(
sqlGenerator(selectLeft.queries.last.sql, selectRight.queries.last.sql))
sqlGenerator(lastQueryLeft.sql, lastQueryRight.sql),
lastQueryLeft.params ++ lastQueryRight.params)
val leftSchemaQuery = schemaValueStatement(selectLeft.attributes)
val rightSchemaQuery = schemaValueStatement(selectRight.attributes)
val schemaQuery = sqlGenerator(leftSchemaQuery, rightSchemaQuery)
val commonColumn = selectLeft.aliasMap.keySet.intersect(selectRight.aliasMap.keySet)
val supportAsyncMode = selectLeft.supportAsyncMode && selectRight.supportAsyncMode
SnowflakePlan(
queries,
Expand All @@ -308,10 +311,11 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
children: Seq[SnowflakePlan],
sourcePlan: Option[LogicalPlan]): SnowflakePlan = wrapException(children: _*) {
val selectChildren = children.map(addResultScanIfNotSelect)
val params: Seq[Any] = selectChildren.map(_.queries.last.params).flatten
val queries: Seq[Query] =
selectChildren
.map(c => c.queries.slice(0, c.queries.length - 1))
.reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql)))
.reduce(_ ++ _) :+ Query(sqlGenerator(selectChildren.map(_.queries.last.sql)), params)

val schemaQueries = children.map(c => schemaValueStatement(c.attributes))
val schemaQuery = sqlGenerator(schemaQueries)
Expand All @@ -323,8 +327,9 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
def query(
sql: String,
sourcePlan: Option[LogicalPlan],
supportAsyncMode: Boolean = true): SnowflakePlan =
SnowflakePlan(Seq(Query(sql)), sql, session, sourcePlan, supportAsyncMode)
supportAsyncMode: Boolean = true,
params: Seq[Any] = Seq.empty): SnowflakePlan =
SnowflakePlan(Seq(Query(sql, params)), sql, session, sourcePlan, supportAsyncMode)

def largeLocalRelationPlan(
output: Seq[Attribute],
Expand Down Expand Up @@ -764,7 +769,8 @@ class SnowflakePlanBuilder(session: Session) extends Logging {
private[snowpark] class Query(
val sql: String,
val queryIdPlaceHolder: String,
val isDDLOnTempObject: Boolean)
val isDDLOnTempObject: Boolean,
val params: Seq[Any])
extends Logging {
logDebug(s"Creating a new Query: $sql ID: $queryIdPlaceHolder")
override def toString: String = sql
Expand All @@ -776,7 +782,7 @@ private[snowpark] class Query(
placeholders.foreach {
case (holder, id) => finalQuery = finalQuery.replaceAll(holder, id)
}
val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters)
val queryId = conn.runQuery(finalQuery, isDDLOnTempObject, statementParameters, params)
placeholders += (queryIdPlaceHolder -> queryId)
queryId
}
Expand All @@ -795,7 +801,8 @@ private[snowpark] class Query(
finalQuery,
!returnIterator,
returnIterator,
conn.getStatementParameters(isDDLOnTempObject, statementParameters))
conn.getStatementParameters(isDDLOnTempObject, statementParameters),
params)
placeholders += (queryIdPlaceHolder -> result.queryId)
result
}
Expand All @@ -806,7 +813,7 @@ private[snowpark] class BatchInsertQuery(
override val queryIdPlaceHolder: String,
attributes: Seq[Attribute],
rows: Seq[Row])
extends Query(sql, queryIdPlaceHolder, false) {
extends Query(sql, queryIdPlaceHolder, false, Seq.empty) {
override def runQuery(
conn: ServerConnection,
placeholders: mutable.HashMap[String, String],
Expand All @@ -832,11 +839,19 @@ object Query {
s"query_id_place_holder_${Random.alphanumeric.take(10).mkString}"

def apply(sql: String): Query = {
new Query(sql, placeHolder(), false)
new Query(sql, placeHolder(), false, Seq.empty)
}

def apply(sql: String, params: Seq[Any]): Query = {
new Query(sql, placeHolder(), false, params)
}

def apply(sql: String, isDDLOnTempObject: Boolean): Query = {
new Query(sql, placeHolder(), isDDLOnTempObject)
new Query(sql, placeHolder(), isDDLOnTempObject, Seq.empty)
}

def apply(sql: String, isDDLOnTempObject: Boolean, params: Seq[Any]): Query = {
new Query(sql, placeHolder(), isDDLOnTempObject, params)
}

def apply(sql: String, attributes: Seq[Attribute], rows: Seq[Row]): Query = {
Expand Down
11 changes: 11 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaSessionSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ public void generator() {
new Row[] {Row.create(1, 2), Row.create(1, 2), Row.create(1, 2)});
}

@Test
public void sql() {
checkAnswer(
getSession().sql("select * from values(1, 2),(3, 4) as t(a, b)"),
new Row[] {Row.create(1, 2), Row.create(3, 4)});

checkAnswer(
getSession().sql("select * from values(?, ?),(?, ?) as t(a, b)", 1, 2, 3, 4),
new Row[] {Row.create(1, 2), Row.create(3, 4)});
}

@Test
public void getSessionStage() {
assert getSession().getSessionStage().contains("SNOWPARK_TEMP_STAGE");
Expand Down
18 changes: 18 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,22 @@ class ServerConnectionSuite extends SNTestBase {
}
}

test("ServerConnection with binding parameters") {
val sql = "select * from values (?),(?),(?)"
val params = Seq(1, 2, 3)

val statement = session.conn.connection.prepareStatement(sql)
params.zipWithIndex.foreach {
case (p, i) => statement.setObject(i + 1, p)
}

val rs = statement.executeQuery()
assert(rs.eq(statement.getResultSet))
rs.next()
assert(rs.getInt(1) == 1)
rs.next()
assert(rs.getInt(1) == 2)
rs.next()
assert(rs.getInt(1) == 3)
}
}
32 changes: 32 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,37 @@ trait SqlSuite extends SNTestBase {
new Directory(new File(outputPath)).deleteRecursively()
}
}

test("Run sql query with bindings") {
val df1 = session.sql("select * from values (?),(?),(?)", List(1, 2, 3))
assert(df1.collect() sameElements Array[Row](Row(1), Row(2), Row(3)))

val df2 =
session.sql(
"select variance(identifier(?)) from values(1,1),(1,2),(2,1),(2,2),(3,1),(3,2) as T(a, b)",
Seq("a"))
assert(df2.collect()(0).getDecimal(0).toString == "0.800000")

val df3 = session
.sql("select * from values (?),(?),(?) as T(id)", Seq(1, 2, 3))
.filter(col("id") < 3)
assert(df3.collect() sameElements Array[Row](Row(1), Row(2)))

val df4 =
session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", Seq(1, 1, 2, 1, 3, 1))
val df5 =
session.sql("select * from values (?,?),(?,?),(?,?) as T(a, b)", List(1, 2, 2, 1, 4, 3))
val df6 = df4.union(df5).filter(col("a") < 3)
assert(df6.collect() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2)))
sfc-gh-evandenberg marked this conversation as resolved.
Show resolved Hide resolved

val df7 = df4.join(df5, Seq("a", "b"), "inner")
assert(df7.collect() sameElements Array[Row](Row(2, 1)))

// Async result
assert(df1.async.collect().getResult() sameElements Array[Row](Row(1), Row(2), Row(3)))
assert(
df6.async.collect().getResult() sameElements Array[Row](Row(1, 1), Row(2, 1), Row(1, 2)))
}
}

class EagerSqlSuite extends SqlSuite with EagerSession {
Expand Down Expand Up @@ -184,6 +215,7 @@ class EagerSqlSuite extends SqlSuite with EagerSession {
assertThrows[SnowflakeSQLException](session.sql("SHOW TABLE"))
}
}

class LazySqlSuite extends SqlSuite with LazySession {
test("Run sql query") {
val df1 = session.sql("select * from values (1),(2),(3)")
Expand Down
Loading