Upgrade Spark 3.5.1 (#525)
---------
Signed-off-by: Peng Huo <[email protected]>
Signed-off-by: Chen Dai <[email protected]>
Co-authored-by: Chen Dai <[email protected]>
penghuo authored Aug 8, 2024
1 parent bcd1942 commit d6e71fa
Showing 29 changed files with 389 additions and 188 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -22,7 +22,7 @@ Version compatibility:
| 0.2.0 | 11+ | 3.3.1 | 2.12.14 | 2.6+ |
| 0.3.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ |
| 0.4.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ |
| 0.5.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ |
| 0.5.0 | 11+ | 3.5.1 | 2.12.14 | 2.13+ |

## Flint Extension Usage

6 changes: 3 additions & 3 deletions build.sbt
@@ -5,10 +5,10 @@
import Dependencies._

lazy val scala212 = "2.12.14"
lazy val sparkVersion = "3.3.2"
// Spark jackson version. Spark jackson-module-scala strictly check the jackson-databind version hould compatbile
lazy val sparkVersion = "3.5.1"
// Spark jackson version. Spark jackson-module-scala strictly checks that the jackson-databind version is compatible
// https://github.com/FasterXML/jackson-module-scala/blob/2.18/src/main/scala/com/fasterxml/jackson/module/scala/JacksonModule.scala#L59
lazy val jacksonVersion = "2.13.4"
lazy val jacksonVersion = "2.15.2"

// The transitive opensearch jackson-databind dependency version should align with Spark jackson databind dependency version.
// Issue: https://github.com/opensearch-project/opensearch-spark/issues/442
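The build.sbt hunk above moves sparkVersion to 3.5.1 and jacksonVersion to 2.15.2 so that the transitive jackson-databind stays within the range jackson-module-scala accepts. As a rough, hypothetical sbt sketch of how such an alignment can be pinned (the dependencyOverrides wiring below is an assumption for illustration, not this project's actual build definition):

// Hypothetical sbt sketch: pin jackson artifacts so the transitive
// jackson-databind matches the version bundled with Spark 3.5.1.
lazy val jacksonVersion = "2.15.2"

ThisBuild / dependencyOverrides ++= Seq(
  "com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion,
  "com.fasterxml.jackson.module" %% "jackson-module-scala" % jacksonVersion
)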
@@ -10,9 +10,9 @@
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.commons.lang.StringUtils;

import com.amazonaws.services.cloudwatch.model.Dimension;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkEnv;

/**
@@ -124,4 +124,4 @@ private static Dimension getEnvironmentVariableDimension(String envVarName, Stri
private static Dimension getDefaultDimension(String dimensionName) {
return getEnvironmentVariableDimension(dimensionName, dimensionName);
}
}
}
@@ -46,7 +46,6 @@
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;
import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -5,7 +5,8 @@

package org.apache.spark.sql.flint

import org.apache.spark.internal.Logging
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}
import org.apache.spark.sql.flint.config.FlintSparkConf
@@ -17,8 +18,7 @@ case class FlintScan(
options: FlintSparkConf,
pushedPredicates: Array[Predicate])
extends Scan
with Batch
with Logging {
with Batch {

override def readSchema(): StructType = schema

@@ -44,10 +44,13 @@
* Print pushedPredicates when explain(mode="extended"). Learn from SPARK JDBCScan.
*/
override def description(): String = {
super.description() + ", PushedPredicates: " + seqToString(pushedPredicates)
super.description() + ", PushedPredicates: " + pushedPredicates
.map {
case p if p.name().equalsIgnoreCase(BloomFilterMightContain.NAME) => p.name()
case p => p.toString()
}
.mkString("[", ", ", "]")
}

private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
}

/**
@@ -5,6 +5,8 @@

package org.apache.spark.sql.flint

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain

import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters}
@@ -34,4 +36,5 @@ case class FlintScanBuilder(
}

override def pushedPredicates(): Array[Predicate] = pushedPredicate
.filterNot(_.name().equalsIgnoreCase(BloomFilterMightContain.NAME))
}
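The FlintScanBuilder change keeps bloom-filter predicates for Flint's own skipping logic while excluding them from the pushedPredicates() report. A minimal, self-contained sketch of that name-based filtering on Spark's connector Predicate API (the column name and predicates are invented for illustration):

// Hypothetical sketch: only non-bloom-filter predicates are reported back.
import org.apache.spark.sql.connector.expressions.{Expression, Expressions}
import org.apache.spark.sql.connector.expressions.filter.Predicate

val isNotNull = new Predicate("IS_NOT_NULL", Array[Expression](Expressions.column("age")))
val bloomFilter = new Predicate("BLOOM_FILTER_MIGHT_CONTAIN", Array.empty[Expression])

val reported = Array(isNotNull, bloomFilter)
  .filterNot(_.name().equalsIgnoreCase("BLOOM_FILTER_MIGHT_CONTAIN"))
// reported now contains only the IS_NOT_NULL predicate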
@@ -48,4 +48,6 @@ case class FlintWrite(
override def toBatch: BatchWrite = this

override def toStreaming: StreamingWrite = this

override def useCommitCoordinator(): Boolean = false
}
@@ -118,7 +118,8 @@ class FlintJacksonParser(
array.toArray[InternalRow](schema)
}
case START_ARRAY =>
throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError()
throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError(
parser.currentToken().asString())
}
}

@@ -420,17 +421,17 @@ class FlintJacksonParser(
case VALUE_STRING if parser.getTextLength < 1 && allowEmptyString =>
dataType match {
case FloatType | DoubleType | TimestampType | DateType =>
throw QueryExecutionErrors.failToParseEmptyStringForDataTypeError(dataType)
throw QueryExecutionErrors.emptyJsonFieldValueError(dataType)
case _ => null
}

case VALUE_STRING if parser.getTextLength < 1 =>
throw QueryExecutionErrors.failToParseEmptyStringForDataTypeError(dataType)
throw QueryExecutionErrors.emptyJsonFieldValueError(dataType)

case token =>
// We cannot parse this token based on the given data type. So, we throw a
// RuntimeException and this exception will be caught by `parse` method.
throw QueryExecutionErrors.failToParseValueForDataTypeError(parser, token, dataType)
throw QueryExecutionErrors.cannotParseJSONFieldError(parser, token, dataType)
}

/**
@@ -537,19 +538,19 @@ class FlintJacksonParser(
// JSON parser currently doesn't support partial results for corrupted records.
// For such records, all fields other than the field configured by
// `columnNameOfCorruptRecord` are set to `null`.
throw BadRecordException(() => recordLiteral(record), () => None, e)
throw BadRecordException(() => recordLiteral(record), cause = e)
case e: CharConversionException if options.encoding.isEmpty =>
val msg =
"""JSON parser cannot handle a character in its input.
|Specifying encoding as an input option explicitly might help to resolve the issue.
|""".stripMargin + e.getMessage
val wrappedCharException = new CharConversionException(msg)
wrappedCharException.initCause(e)
throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException)
throw BadRecordException(() => recordLiteral(record), cause = wrappedCharException)
case PartialResultException(row, cause) =>
throw BadRecordException(
record = () => recordLiteral(record),
partialResult = () => Some(row),
partialResults = () => Array(row),
cause)
}
}
@@ -9,6 +9,7 @@ import java.util.concurrent.ScheduledExecutorService

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.internal.SQLConf.DEFAULT_CATALOG
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.{ShutdownHookManager, ThreadUtils}

@@ -72,14 +73,8 @@ package object flint {
def qualifyTableName(spark: SparkSession, tableName: String): String = {
val (catalog, ident) = parseTableName(spark, tableName)

// Tricky that our Flint delegate catalog's name has to be spark_catalog
// so we have to find its actual name in CatalogManager
val catalogMgr = spark.sessionState.catalogManager
val catalogName =
catalogMgr
.listCatalogs(Some("*"))
.find(catalogMgr.catalog(_) == catalog)
.getOrElse(catalog.name())
// more reading at https://github.com/opensearch-project/opensearch-spark/issues/319.
val catalogName = resolveCatalogName(spark, catalog)

s"$catalogName.${ident.namespace.mkString(".")}.${ident.name}"
}
@@ -134,4 +129,41 @@ package object flint {
def findField(rootField: StructType, fieldName: String): Option[StructField] = {
rootField.findNestedField(fieldName.split('.')).map(_._2)
}

/**
* Resolve the catalog name. The spark.sql.defaultCatalog name is returned if catalog.name is
* spark_catalog; otherwise, catalog.name is returned.
* @see
* <a href="https://github.com/opensearch-project/opensearch-spark/issues/319#issuecomment
* -2099630984">issue319</a>
*
* @param spark
* Spark Session
* @param catalog
* Spark Catalog
* @return
* catalog name.
*/
def resolveCatalogName(spark: SparkSession, catalog: CatalogPlugin): String = {

/**
* Check if the provided catalog is a session catalog.
*/
if (CatalogV2Util.isSessionCatalog(catalog)) {
val defaultCatalog = spark.conf.get(DEFAULT_CATALOG)
if (spark.sessionState.catalogManager.isCatalogRegistered(defaultCatalog)) {
defaultCatalog
} else {

/**
* This may happen when spark.sql.defaultCatalog is configured but no catalog
* implementation is registered for it. For instance, spark.sql.defaultCatalog = "unknown"
*/
throw new IllegalStateException(s"Unknown catalog name: $defaultCatalog")
}
} else {
// Return the name for non-session catalogs
catalog.name()
}
}
}
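Together with qualifyTableName above, resolveCatalogName means tables reached through the session catalog are qualified under the configured spark.sql.defaultCatalog name instead of the literal spark_catalog. A rough usage sketch (the catalog and table names are made up, and a "glue" catalog implementation is assumed to be registered):

// Hypothetical sketch of the resolution behavior described above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.flint.qualifyTableName

val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.defaultCatalog", "glue") // assumes a "glue" catalog is registered
  .getOrCreate()

val qualified = qualifyTableName(spark, "default.http_logs")
// expected: "glue.default.http_logs"; an unregistered default catalog would
// instead fail with IllegalStateException inside resolveCatalogName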
@@ -10,7 +10,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COL

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusWithMetadata, PartitionDirectory}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.functions.isnull
import org.apache.spark.sql.types.StructType
@@ -96,7 +96,7 @@ case class FlintSparkSkippingFileIndex(
.toSet
}

private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = {
private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatusWithMetadata) = {
selectedFiles.contains(f.getPath.toUri.toString)
}
}
@@ -11,7 +11,10 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKi

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType

/**
* Skipping index strategy that defines skipping data structure building and reading logic.
@@ -115,6 +118,17 @@ object FlintSparkSkippingStrategy {
Seq(attr.name)
case GetStructField(child, _, Some(name)) =>
extractColumnName(child) :+ name
/**
* Since Spark 3.4 added read-side padding, char_col = "sample char" becomes
* (staticinvoke(class org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils,
* StringType, readSidePadding, char_col#47, 20, true, false, true) = sample char )
*
* When creating the skipping index, Spark already applied write-side padding, so the
* read-side padding can be ignored on push down. More reading: https://issues.apache.org/jira/browse/SPARK-40697
*/
case StaticInvoke(staticObject, StringType, "readSidePadding", arguments, _, _, _, _)
if classOf[CharVarcharCodegenUtils].isAssignableFrom(staticObject) =>
extractColumnName(arguments.head)
case _ => Seq.empty
}
}
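A small illustration of the expression shape the new StaticInvoke case handles (the column name, length, and the call to extractColumnName are assumptions based on the surrounding code, not taken from this commit):

// Hypothetical reconstruction of the read-side-padded predicate child that
// Spark 3.4+ builds for a CHAR(20) column filter like char_col = 'sample char'.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.types.StringType

val charCol = AttributeReference("char_col", StringType)()
val padded = StaticInvoke(
  classOf[CharVarcharCodegenUtils],
  StringType,
  "readSidePadding",
  Seq(charCol, Literal(20)))

// extractColumnName(padded) is expected to unwrap the StaticInvoke and return
// Seq("char_col"), so the write-side-padded skipping index column still matches.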
@@ -8,6 +8,7 @@ package org.opensearch.flint.spark.skipping.bloomfilter
import java.io.ByteArrayInputStream

import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.NAME

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
@@ -40,7 +41,7 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre

override def dataType: DataType = BooleanType

override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"
override def symbol: String = NAME

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
@@ -109,6 +110,8 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre

object BloomFilterMightContain {

val NAME = "BLOOM_FILTER_MIGHT_CONTAIN"

/**
* Generate bloom filter might contain function given the bloom filter column and value.
*
@@ -0,0 +1,91 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.mockito.Mockito.when
import org.scalatestplus.mockito.MockitoSugar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin}
import org.apache.spark.sql.flint.resolveCatalogName
import org.apache.spark.sql.internal.{SessionState, SQLConf}
import org.apache.spark.sql.internal.SQLConf.DEFAULT_CATALOG

class FlintCatalogSuite extends SparkFunSuite with MockitoSugar {

test("resolveCatalogName returns default catalog name for session catalog") {
assertCatalog()
.withCatalogName("spark_catalog")
.withDefaultCatalog("glue")
.registerCatalog("glue")
.shouldResolveCatalogName("glue")
}

test("resolveCatalogName returns default catalog name for spark_catalog") {
assertCatalog()
.withCatalogName("spark_catalog")
.withDefaultCatalog("spark_catalog")
.registerCatalog("spark_catalog")
.shouldResolveCatalogName("spark_catalog")
}

test("resolveCatalogName should return catalog name for non-session catalogs") {
assertCatalog()
.withCatalogName("custom_catalog")
.withDefaultCatalog("custom_catalog")
.registerCatalog("custom_catalog")
.shouldResolveCatalogName("custom_catalog")
}

test(
"resolveCatalogName should throw RuntimeException when default catalog is not registered") {
assertCatalog()
.withCatalogName("spark_catalog")
.withDefaultCatalog("glue")
.registerCatalog("unknown")
.shouldThrowException()
}

private def assertCatalog(): AssertionHelper = {
new AssertionHelper
}

private class AssertionHelper {
private val spark = mock[SparkSession]
private val catalog = mock[CatalogPlugin]
private val sessionState = mock[SessionState]
private val catalogManager = mock[CatalogManager]

def withCatalogName(catalogName: String): AssertionHelper = {
when(catalog.name()).thenReturn(catalogName)
this
}

def withDefaultCatalog(catalogName: String): AssertionHelper = {
val conf = new SQLConf
conf.setConf(DEFAULT_CATALOG, catalogName)
when(spark.conf).thenReturn(new RuntimeConfig(conf))
this
}

def registerCatalog(catalogName: String): AssertionHelper = {
when(spark.sessionState).thenReturn(sessionState)
when(sessionState.catalogManager).thenReturn(catalogManager)
when(catalogManager.isCatalogRegistered(catalogName)).thenReturn(true)
this
}

def shouldResolveCatalogName(expectedCatalogName: String): Unit = {
assert(resolveCatalogName(spark, catalog) == expectedCatalogName)
}

def shouldThrowException(): Unit = {
assertThrows[IllegalStateException] {
resolveCatalogName(spark, catalog)
}
}
}
}