Skip to content

Commit

Permalink
[SPARK-49695][SQL] Postgres fix xor push-down
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR fixes the pushdown of ^ operator (XOR operator) for Postgres. Those two databases use this as exponent, rather then bitwise xor.

Fix is consisted of overriding the SQLExpressionBuilder to replace the '^' character with '#'.
### Why are the changes needed?
Result is incorrect.

### Does this PR introduce _any_ user-facing change?
Yes. The user will now have a proper translation of the ^ operator.

### How was this patch tested?

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48144 from andrej-db/SPARK-49695-PostgresXOR.

Lead-authored-by: Andrej Gobeljić <[email protected]>
Co-authored-by: andrej-db <[email protected]>
Co-authored-by: andrej-gobeljic_data <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
andrej-db authored and MaxGekk committed Dec 4, 2024
1 parent f1eecd3 commit 4248397
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.sql.Connection
import org.apache.spark.{SparkConf, SparkSQLException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.jdbc.PostgresDatabaseOnDocker
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -243,6 +244,15 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
}
}

test("SPARK-49695: Postgres fix xor push-down") {
val df = spark.sql(s"select dept, name from $catalogName.employee where dept ^ 6 = 0")
val rows = df.collect()
assert(!df.queryExecution.sparkPlan.exists(_.isInstanceOf[FilterExec]))
assert(rows.length == 1)
assert(rows(0).getInt(0) === 6)
assert(rows(0).getString(1) === "jen")
}

override def testDatetime(tbl: String): Unit = {
val df1 = sql(s"SELECT name FROM $tbl WHERE " +
"dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ private case class PostgresDialect()
case _ => super.visitExtract(field, source)
}
}

override def visitBinaryArithmetic(name: String, l: String, r: String): String = {
l + " " + name.replace('^', '#') + " " + r
}
}

override def compileExpression(expr: Expression): Option[String] = {
Expand Down

0 comments on commit 4248397

Please sign in to comment.