Upgrade Spark 3.5.1 #525

Merged: 13 commits, Aug 8, 2024
README.md: 2 changes (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ Version compatibility:
| 0.2.0 | 11+ | 3.3.1 | 2.12.14 | 2.6+ |
| 0.3.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ |
| 0.4.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ |
| 0.5.0 | 11+ | 3.3.2 | 2.12.14 | 2.13+ |
| 0.5.0 | 11+ | 3.5.1 | 2.12.14 | 2.13+ |

## Flint Extension Usage

build.sbt: 6 changes (3 additions, 3 deletions)
@@ -5,10 +5,10 @@
import Dependencies._

lazy val scala212 = "2.12.14"
lazy val sparkVersion = "3.3.2"
// Spark jackson version. Spark jackson-module-scala strictly check the jackson-databind version hould compatbile
lazy val sparkVersion = "3.5.1"
// Spark Jackson version. Spark's jackson-module-scala strictly checks that the jackson-databind version is compatible.
// https://github.com/FasterXML/jackson-module-scala/blob/2.18/src/main/scala/com/fasterxml/jackson/module/scala/JacksonModule.scala#L59
lazy val jacksonVersion = "2.13.4"
lazy val jacksonVersion = "2.15.2"

// The transitive opensearch jackson-databind dependency version should align with Spark jackson databind dependency version.
// Issue: https://github.com/opensearch-project/opensearch-spark/issues/442
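// Illustrative sketch, not a line from this change: one way to keep the transitive
// jackson-databind pulled in by the OpenSearch client aligned with Spark's Jackson is an
// explicit override keyed to the jacksonVersion defined above.
dependencyOverrides ++= Seq(
  "com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion,
  "com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion
)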
@@ -10,9 +10,9 @@
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.commons.lang.StringUtils;

import com.amazonaws.services.cloudwatch.model.Dimension;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkEnv;

/**
@@ -124,4 +124,4 @@ private static Dimension getEnvironmentVariableDimension(String envVarName, Stri
private static Dimension getDefaultDimension(String dimensionName) {
return getEnvironmentVariableDimension(dimensionName, dimensionName);
}
}
}
@@ -46,7 +46,6 @@
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;
import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -5,7 +5,8 @@

package org.apache.spark.sql.flint

import org.apache.spark.internal.Logging
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}
import org.apache.spark.sql.flint.config.FlintSparkConf
@@ -17,8 +18,7 @@ case class FlintScan(
options: FlintSparkConf,
pushedPredicates: Array[Predicate])
extends Scan
with Batch
with Logging {
with Batch {

override def readSchema(): StructType = schema

@@ -44,10 +44,13 @@
* Print pushedPredicates when explain(mode="extended") is used. Modeled after Spark's JDBCScan.
*/
override def description(): String = {
super.description() + ", PushedPredicates: " + seqToString(pushedPredicates)
super.description() + ", PushedPredicates: " + pushedPredicates
.map {
case p if p.name().equalsIgnoreCase(BloomFilterMightContain.NAME) => p.name()
case p => p.toString()
}
.mkString("[", ", ", "]")
}
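// Illustration (hypothetical predicates; the exact Predicate.toString rendering may differ):
// given pushedPredicates = [status = 500, bloom_filter_might_contain(...)], description() appends
// ", PushedPredicates: [status = 500, BLOOM_FILTER_MIGHT_CONTAIN]"; the bloom filter predicate
// is reported by name only rather than via its full toString form.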

private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
}

/**
@@ -5,6 +5,8 @@

package org.apache.spark.sql.flint

import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain

import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters}
@@ -34,4 +36,5 @@ case class FlintScanBuilder(
}

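// Report the pushed predicates back to Spark, filtering out the internal
// BLOOM_FILTER_MIGHT_CONTAIN function so only standard predicates are surfaced.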
override def pushedPredicates(): Array[Predicate] = pushedPredicate
.filterNot(_.name().equalsIgnoreCase(BloomFilterMightContain.NAME))
}
@@ -48,4 +48,6 @@ case class FlintWrite(
override def toBatch: BatchWrite = this

override def toStreaming: StreamingWrite = this

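// Opt out of Spark's driver-side output commit coordination for Flint batch writes.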
override def useCommitCoordinator(): Boolean = false
}
@@ -118,7 +118,8 @@ class FlintJacksonParser(
array.toArray[InternalRow](schema)
}
case START_ARRAY =>
throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError()
throw QueryExecutionErrors.cannotParseJsonArraysAsStructsError(
parser.currentToken().asString())
}
}

@@ -420,17 +421,17 @@ class FlintJacksonParser(
case VALUE_STRING if parser.getTextLength < 1 && allowEmptyString =>
dataType match {
case FloatType | DoubleType | TimestampType | DateType =>
throw QueryExecutionErrors.failToParseEmptyStringForDataTypeError(dataType)
throw QueryExecutionErrors.emptyJsonFieldValueError(dataType)
case _ => null
}

case VALUE_STRING if parser.getTextLength < 1 =>
throw QueryExecutionErrors.failToParseEmptyStringForDataTypeError(dataType)
throw QueryExecutionErrors.emptyJsonFieldValueError(dataType)

case token =>
// We cannot parse this token based on the given data type. So, we throw a
// RuntimeException and this exception will be caught by `parse` method.
throw QueryExecutionErrors.failToParseValueForDataTypeError(parser, token, dataType)
throw QueryExecutionErrors.cannotParseJSONFieldError(parser, token, dataType)
}

/**
@@ -537,19 +538,19 @@ class FlintJacksonParser(
// JSON parser currently doesn't support partial results for corrupted records.
// For such records, all fields other than the field configured by
// `columnNameOfCorruptRecord` are set to `null`.
throw BadRecordException(() => recordLiteral(record), () => None, e)
throw BadRecordException(() => recordLiteral(record), cause = e)
case e: CharConversionException if options.encoding.isEmpty =>
val msg =
"""JSON parser cannot handle a character in its input.
|Specifying encoding as an input option explicitly might help to resolve the issue.
|""".stripMargin + e.getMessage
val wrappedCharException = new CharConversionException(msg)
wrappedCharException.initCause(e)
throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException)
throw BadRecordException(() => recordLiteral(record), cause = wrappedCharException)
case PartialResultException(row, cause) =>
throw BadRecordException(
record = () => recordLiteral(record),
partialResult = () => Some(row),
partialResults = () => Array(row),
cause)
}
}
@@ -9,6 +9,7 @@ import java.util.concurrent.ScheduledExecutorService

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.internal.SQLConf.DEFAULT_CATALOG
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.{ShutdownHookManager, ThreadUtils}

@@ -72,14 +73,8 @@ package object flint {
def qualifyTableName(spark: SparkSession, tableName: String): String = {
val (catalog, ident) = parseTableName(spark, tableName)

// Tricky that our Flint delegate catalog's name has to be spark_catalog
// so we have to find its actual name in CatalogManager
val catalogMgr = spark.sessionState.catalogManager
val catalogName =
catalogMgr
.listCatalogs(Some("*"))
.find(catalogMgr.catalog(_) == catalog)
.getOrElse(catalog.name())
// more reading at https://github.com/opensearch-project/opensearch-spark/issues/319.
val catalogName = resolveCatalogName(spark, catalog)

s"$catalogName.${ident.namespace.mkString(".")}.${ident.name}"
}
@@ -134,4 +129,41 @@
def findField(rootField: StructType, fieldName: String): Option[StructField] = {
rootField.findNestedField(fieldName.split('.')).map(_._2)
}

/**
* Resolve the catalog name. The spark.sql.defaultCatalog value is returned if the catalog is
* the session catalog (spark_catalog); otherwise catalog.name() is returned.
* @see
* <a href="https://github.com/opensearch-project/opensearch-spark/issues/319#issuecomment-2099630984">issue319</a>
*
* @param spark
* Spark Session
* @param catalog
* Spark Catalog
* @return
* catalog name.
*/
def resolveCatalogName(spark: SparkSession, catalog: CatalogPlugin): String = {

/**
* Check if the provided catalog is a session catalog.
*/
if (CatalogV2Util.isSessionCatalog(catalog)) {
val defaultCatalog = spark.conf.get(DEFAULT_CATALOG)
if (spark.sessionState.catalogManager.isCatalogRegistered(defaultCatalog)) {
defaultCatalog
} else {

/**
* This can happen when spark.sql.defaultCatalog is configured but no catalog implementation
* is registered for it, e.g. spark.sql.defaultCatalog = "unknown".
*/
throw new IllegalStateException(s"Unknown catalog name: $defaultCatalog")
}
} else {
// Return the name for non-session catalogs
catalog.name()
}
}
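// Illustrative behaviour (names are hypothetical): with spark.sql.defaultCatalog = "glue" and a
// "glue" catalog registered, resolving the session catalog returns "glue", so qualifyTableName
// produces "glue.default.http_logs" rather than "spark_catalog.default.http_logs".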
}
@@ -10,7 +10,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COL

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.execution.datasources.{FileIndex, FileStatusWithMetadata, PartitionDirectory}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.functions.isnull
import org.apache.spark.sql.types.StructType
@@ -96,7 +96,7 @@ case class FlintSparkSkippingFileIndex(
.toSet
}

private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = {
private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatusWithMetadata) = {
selectedFiles.contains(f.getPath.toUri.toString)
}
}
@@ -11,7 +11,10 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKi

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GetStructField}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType

/**
* Skipping index strategy that defines skipping data structure building and reading logic.
@@ -115,6 +118,17 @@ object FlintSparkSkippingStrategy {
Seq(attr.name)
case GetStructField(child, _, Some(name)) =>
extractColumnName(child) :+ name
/**
* Since Spark 3.4 added read-side padding for CHAR columns, a predicate like
* char_col = "sample char" becomes
* (staticinvoke(class org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils,
* StringType, readSidePadding, char_col#47, 20, true, false, true) = sample char )
*
* Because Spark already applies write-side padding when the skipping index is created, the
* read-side padding can be ignored during push-down. More reading:
* https://issues.apache.org/jira/browse/SPARK-40697
*/
case StaticInvoke(staticObject, StringType, "readSidePadding", arguments, _, _, _, _)
if classOf[CharVarcharCodegenUtils].isAssignableFrom(staticObject) =>
extractColumnName(arguments.head)
case _ => Seq.empty
}
}
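// Net effect: for a padded CHAR column predicate like the staticinvoke example above,
// extractColumnName still resolves to the underlying column name (e.g. Seq("char_col")).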
@@ -8,6 +8,7 @@ package org.opensearch.flint.spark.skipping.bloomfilter
import java.io.ByteArrayInputStream

import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter
import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterMightContain.NAME

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
@@ -40,7 +41,7 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre

override def dataType: DataType = BooleanType

override def symbol: String = "BLOOM_FILTER_MIGHT_CONTAIN"
override def symbol: String = NAME

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
@@ -109,6 +110,8 @@ case class BloomFilterMightContain(bloomFilterExpression: Expression, valueExpre

object BloomFilterMightContain {

val NAME = "BLOOM_FILTER_MIGHT_CONTAIN"

/**
* Generate bloom filter might contain function given the bloom filter column and value.
*
@@ -0,0 +1,91 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql

import org.mockito.Mockito.when
import org.scalatestplus.mockito.MockitoSugar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin}
import org.apache.spark.sql.flint.resolveCatalogName
import org.apache.spark.sql.internal.{SessionState, SQLConf}
import org.apache.spark.sql.internal.SQLConf.DEFAULT_CATALOG

class FlintCatalogSuite extends SparkFunSuite with MockitoSugar {

test("resolveCatalogName returns default catalog name for session catalog") {
assertCatalog()
.withCatalogName("spark_catalog")
.withDefaultCatalog("glue")
.registerCatalog("glue")
.shouldResolveCatalogName("glue")
}

test("resolveCatalogName returns default catalog name for spark_catalog") {
assertCatalog()
.withCatalogName("spark_catalog")
.withDefaultCatalog("spark_catalog")
.registerCatalog("spark_catalog")
.shouldResolveCatalogName("spark_catalog")
}

test("resolveCatalogName should return catalog name for non-session catalogs") {
assertCatalog()
.withCatalogName("custom_catalog")
.withDefaultCatalog("custom_catalog")
.registerCatalog("custom_catalog")
.shouldResolveCatalogName("custom_catalog")
}

test(
"resolveCatalogName should throw RuntimeException when default catalog is not registered") {
assertCatalog()
.withCatalogName("spark_catalog")
.withDefaultCatalog("glue")
.registerCatalog("unknown")
.shouldThrowException()
}

private def assertCatalog(): AssertionHelper = {
new AssertionHelper
}

private class AssertionHelper {
private val spark = mock[SparkSession]
private val catalog = mock[CatalogPlugin]
private val sessionState = mock[SessionState]
private val catalogManager = mock[CatalogManager]

def withCatalogName(catalogName: String): AssertionHelper = {
when(catalog.name()).thenReturn(catalogName)
this
}

def withDefaultCatalog(catalogName: String): AssertionHelper = {
val conf = new SQLConf
conf.setConf(DEFAULT_CATALOG, catalogName)
when(spark.conf).thenReturn(new RuntimeConfig(conf))
this
}

def registerCatalog(catalogName: String): AssertionHelper = {
when(spark.sessionState).thenReturn(sessionState)
when(sessionState.catalogManager).thenReturn(catalogManager)
when(catalogManager.isCatalogRegistered(catalogName)).thenReturn(true)
this
}

def shouldResolveCatalogName(expectedCatalogName: String): Unit = {
assert(resolveCatalogName(spark, catalog) == expectedCatalogName)
}

def shouldThrowException(): Unit = {
assertThrows[IllegalStateException] {
resolveCatalogName(spark, catalog)
}
}
}
}