Skip to content

Commit

Permalink
[SPARK-49488][SQL] MySQL dialect supports pushdown datetime functions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR propose to make MySQL dialect supports pushdown datetime functions.

### Why are the changes needed?
Currently, DS V2 pushdown framework pushed the datetime functions with in a common way. But MySQL doesn't support some datetime functions.

### Does this PR introduce _any_ user-facing change?
'No'.
This is a new feature for MySQL dialect.

### How was this patch tested?
GA.

### Was this patch authored or co-authored using generative AI tooling?
'No'.

Closes #47951 from beliefer/SPARK-49488.

Authored-by: beliefer <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
beliefer authored and cloud-fan committed Sep 13, 2024
1 parent 5533c81 commit 9fc58aa
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
""".stripMargin
|""".stripMargin
).executeUpdate()
connection.prepareStatement(
"CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
.executeUpdate()
}

override def dataPreparation(connection: Connection): Unit = {
super.dataPreparation(connection)
connection.prepareStatement("INSERT INTO datetime VALUES " +
"('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate()
connection.prepareStatement("INSERT INTO datetime VALUES " +
"('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
}

override def testUpdateColumnType(tbl: String): Unit = {
Expand Down Expand Up @@ -157,6 +168,79 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 65536)
}
}

override def testDatetime(tbl: String): Unit = {
val df1 = sql(s"SELECT name FROM $tbl WHERE " +
"dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ")
checkFilterPushed(df1)
val rows1 = df1.collect()
assert(rows1.length === 2)
assert(rows1(0).getString(0) === "amy")
assert(rows1(1).getString(0) === "alex")

val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2")
checkFilterPushed(df2)
val rows2 = df2.collect()
assert(rows2.length === 2)
assert(rows2(0).getString(0) === "amy")
assert(rows2(1).getString(0) === "alex")

val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5")
checkFilterPushed(df3)
val rows3 = df3.collect()
assert(rows3.length === 2)
assert(rows3(0).getString(0) === "amy")
assert(rows3(1).getString(0) === "alex")

val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0")
checkFilterPushed(df4)
val rows4 = df4.collect()
assert(rows4.length === 2)
assert(rows4(0).getString(0) === "amy")
assert(rows4(1).getString(0) === "alex")

val df5 = sql(s"SELECT name FROM $tbl WHERE " +
"extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022")
checkFilterPushed(df5)
val rows5 = df5.collect()
assert(rows5.length === 2)
assert(rows5(0).getString(0) === "amy")
assert(rows5(1).getString(0) === "alex")

val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " +
"AND datediff(date1, '2022-05-10') > 0")
checkFilterPushed(df6)
val rows6 = df6.collect()
assert(rows6.length === 1)
assert(rows6(0).getString(0) === "amy")

val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2")
checkFilterPushed(df7)
val rows7 = df7.collect()
assert(rows7.length === 1)
assert(rows7(0).getString(0) === "alex")

val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4")
checkFilterPushed(df8)
val rows8 = df8.collect()
assert(rows8.length === 1)
assert(rows8(0).getString(0) === "alex")

val df9 = sql(s"SELECT name FROM $tbl WHERE " +
"dayofyear(date1) > 100 order by dayofyear(date1) limit 1")
checkFilterPushed(df9)
val rows9 = df9.collect()
assert(rows9.length === 1)
assert(rows9(0).getString(0) === "alex")

// MySQL does not support
val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'")
checkFilterPushed(df10, false)
val rows10 = df10.collect()
assert(rows10.length === 2)
assert(rows10(0).getString(0) === "amy")
assert(rows10(1).getString(0) === "alex")
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}

private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
val filter = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
Expand Down Expand Up @@ -980,4 +980,10 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
)
}
}

def testDatetime(tbl: String): Unit = {}

test("scan with filter push-down with date time functions") {
testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,33 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
// See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
private val supportedAggregateFunctions =
Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions ++ Set("DATE_ADD", "DATE_DIFF")

override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)

class MySQLSQLBuilder extends JDBCSQLBuilder {
override def visitExtract(field: String, source: String): String = {
field match {
case "DAY_OF_YEAR" => s"DAYOFYEAR($source)"
case "YEAR_OF_WEEK" => s"EXTRACT(YEAR FROM $source)"
// WEEKDAY uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ...,
// so we use the formula (WEEKDAY + 1) to follow the ISO standard.
case "DAY_OF_WEEK" => s"(WEEKDAY($source) + 1)"
case _ => super.visitExtract(field, source)
}
}

override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
funcName match {
case "DATE_ADD" =>
s"DATE_ADD(${inputs(0)}, INTERVAL ${inputs(1)} DAY)"
case "DATE_DIFF" =>
s"DATEDIFF(${inputs(0)}, ${inputs(1)})"
case _ => super.visitSQLFunction(funcName, inputs)
}
}

override def visitSortOrder(
sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
(sortDirection, nullOrdering) match {
Expand Down

0 comments on commit 9fc58aa

Please sign in to comment.