diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 60685f5c0c6b9..700c05b54a256 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -77,8 +77,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest s"""CREATE TABLE pattern_testing_table ( |pattern_testing_col LONGTEXT |) - """.stripMargin + |""".stripMargin ).executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -157,6 +168,79 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 65536) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df3) + val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 2) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + + val df5 = sql(s"SELECT name FROM $tbl WHERE " + + "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022") + checkFilterPushed(df5) + val rows5 = df5.collect() + assert(rows5.length === 2) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4") + checkFilterPushed(df8) + val rows8 = df8.collect() + assert(rows8.length === 1) + assert(rows8(0).getString(0) === "alex") + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + // MySQL does not support + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10, false) + val rows10 = df10.collect() + assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } } /** diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index d3629d871cd42..54635f69f8b65 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -353,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { + protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { val filter = df.queryExecution.optimizedPlan.collect { case f: Filter => f } @@ -980,4 +980,10 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu ) } } + + def testDatetime(tbl: String): Unit = {} + + test("scan with filter push-down with date time functions") { + testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index f2b626490d13c..785bf5b13aa78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -46,12 +46,33 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions - private val supportedFunctions = supportedAggregateFunctions + private val supportedFunctions = supportedAggregateFunctions ++ Set("DATE_ADD", "DATE_DIFF") override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) class MySQLSQLBuilder extends JDBCSQLBuilder { + override def visitExtract(field: String, source: String): String = { + field match { + case "DAY_OF_YEAR" => s"DAYOFYEAR($source)" + case "YEAR_OF_WEEK" => s"EXTRACT(YEAR FROM $source)" + // WEEKDAY uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ..., + // so we use the formula (WEEKDAY + 1) to follow the ISO standard. + case "DAY_OF_WEEK" => s"(WEEKDAY($source) + 1)" + case _ => super.visitExtract(field, source) + } + } + + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + funcName match { + case "DATE_ADD" => + s"DATE_ADD(${inputs(0)}, INTERVAL ${inputs(1)} DAY)" + case "DATE_DIFF" => + s"DATEDIFF(${inputs(0)}, ${inputs(1)})" + case _ => super.visitSQLFunction(funcName, inputs) + } + } + override def visitSortOrder( sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = { (sortDirection, nullOrdering) match {