Add test cases for ORC writing according to options orc.compress and compression [databricks] #8785

Merged: 6 commits, Jul 29, 2023
Changes from 4 commits
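For context on what the new tests verify: when both options are set on an ORC write, the Spark-level `compression` option takes precedence over the ORC-level `orc.compress` (OrcConf.COMPRESS) setting, and `UNCOMPRESSED` is mapped to ORC's `NONE`. A minimal, self-contained sketch of that behavior (the local session and output path are illustrative and not part of this PR):

import org.apache.spark.sql.SparkSession

object OrcCompressionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("orc-compression-demo")
      .getOrCreate()

    // `compression` wins over `orc.compress`, so this write produces
    // SNAPPY-compressed files named like part-*.snappy.orc
    spark.range(0, 10)
      .write
      .option("orc.compress", "NONE")  // ORC-level option (OrcConf.COMPRESS)
      .option("compression", "SNAPPY") // Spark writer option; takes precedence
      .orc("/tmp/orc-compression-demo")

    spark.stop()
  }
}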
6 changes: 6 additions & 0 deletions tests/pom.xml
@@ -350,6 +350,12 @@
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.orc</groupId>
<artifactId>orc-core</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
</profile>
<profile>
@@ -18,7 +18,6 @@ package com.nvidia.spark.rapids

import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.FileUtils.withTempPath
import org.apache.spark.sql.rapids.GpuFileSourceScanExec

class GpuFileScanPrunePartitionSuite extends SparkQueryCompareTestSuite {
@@ -20,7 +20,6 @@ import com.nvidia.spark.rapids.RapidsReaderType._
import com.nvidia.spark.rapids.shims.GpuBatchScanExec

import org.apache.spark.SparkConf
import org.apache.spark.sql.FileUtils.withTempPath
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.rapids.{ExternalSource, GpuFileSourceScanExec}
82 changes: 82 additions & 0 deletions tests/src/test/scala/com/nvidia/spark/rapids/OrcQuerySuite.scala
@@ -18,13 +18,23 @@ package com.nvidia.spark.rapids

import java.io.File

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.Arm.withResourceIfAllowed
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileUtil.fullyDelete
import org.apache.hadoop.fs.Path
import org.apache.orc.OrcFile

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.rapids.{MyDenseVector, MyDenseVectorUDT}
import org.apache.spark.sql.types._

/**
* This corresponds to the Spark class:
* org.apache.spark.sql.execution.datasources.orc.OrcQueryTest
*/
class OrcQuerySuite extends SparkQueryCompareTestSuite {

private def getSchema: StructType = new StructType(Array(
@@ -92,4 +102,76 @@ class OrcQuerySuite extends SparkQueryCompareTestSuite {
) {
frame => frame
}

private def getOrcFilePostfix(compression: String): String =
Review comment (Collaborator) — suggested change: rename `getOrcFilePostfix` to `getOrcFileSuffix`.

if (Seq("NONE", "UNCOMPRESSED").contains(compression)) {
".orc"
} else {
s".${compression.toLowerCase()}.orc"
}

def checkCompressType(compression: Option[String], orcCompress: Option[String]): Unit = {
withGpuSparkSession { spark =>
withTempPath { file =>
var writer = spark.range(0, 10).write
writer = compression.map(t => writer.option("compression", t)).getOrElse(writer)
writer = orcCompress.map(t => writer.option("orc.compress", t)).getOrElse(writer)
// write ORC file on GPU
writer.orc(file.getCanonicalPath)

// expectedType: prefer `compression`, then fall back to `orc.compress`
var expectedType = compression.getOrElse(orcCompress.get)
// ORC uses NONE for UNCOMPRESSED
if (expectedType == "UNCOMPRESSED") expectedType = "NONE"
val maybeOrcFile = file.listFiles()
.find(_.getName.endsWith(getOrcFilePostfix(expectedType)))
assert(maybeOrcFile.isDefined)

// check the compress type using ORC jar
val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
val conf = OrcFile.readerOptions(new Configuration())

// the reader is not an AutoCloseable for Spark CDH, so use `withResourceIfAllowed`
// 321cdh uses an older ORC: orc-core-1.5.1.7.1.7.1000-141.jar
// 330cdh uses an older ORC: orc-core-1.5.1.7.1.8.0-801.jar
withResourceIfAllowed(OrcFile.createReader(orcFilePath, conf)) { reader =>
Review comment (Collaborator): This is a useful utility. It makes the code quite streamlined. (A minimal sketch of such a helper appears after this file's diff.)

// check
assert(expectedType === reader.getCompressionKind.name)
}
}
}
}

private val supportedWriteCompressTypes = {
// GPU ORC writing does not support ZLIB, LZ4, refer to GpuOrcFileFormat
val supportedWriteCompressType = ListBuffer("UNCOMPRESSED", "NONE", "ZSTD", "SNAPPY")
// Cdh321 and Cdh330 do not support ZSTD, refer to the Cdh class:
// org.apache.spark.sql.execution.datasources.orc.OrcOptions
// Spark 3.1.x does not support lz4 or zstd
if (isCdh321 || isCdh330 || !VersionUtils.isSpark320OrLater) {
supportedWriteCompressType -= "ZSTD"
}
supportedWriteCompressType
}

test("SPARK-16610: Respect orc.compress (i.e., OrcConf.COMPRESS) when compression is unset") {
// Respect `orc.compress` (i.e., OrcConf.COMPRESS).
supportedWriteCompressTypes.foreach { orcCompress =>
checkCompressType(None, Some(orcCompress))
}

// make paris, e.g.: [("UNCOMPRESSED", "NONE"), ("NONE", "SNAPPY"), ("SNAPPY", "ZSTD") ... ]
Review comment (Collaborator) — suggested change: fix the typo "make paris" → "make pairs" in the comment above.

val pairs = supportedWriteCompressTypes.sliding(2).toList.map(pair => (pair.head, pair.last))

// "compression" overwrite "orc.compress"
pairs.foreach { case (compression, orcCompress) =>
checkCompressType(Some(compression), Some(orcCompress))
}
}

test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") {
supportedWriteCompressTypes.foreach { compression =>
checkCompressType(Some(compression), None)
}
}
}
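
The review note on `withResourceIfAllowed` above refers to a helper from `com.nvidia.spark.rapids.Arm`. A minimal sketch of what such a helper can look like, assuming it simply closes the resource when the resource happens to be `AutoCloseable` (the project's actual implementation may differ):

object ArmSketch {
  // Runs `block` on the resource and closes it afterwards only when the resource
  // implements AutoCloseable. This is why it is convenient for ORC readers on the
  // CDH builds mentioned above, where the reader is not AutoCloseable: the close is skipped.
  def withResourceIfAllowed[T, V](resource: T)(block: T => V): V = {
    try {
      block(resource)
    } finally {
      resource match {
        case c: AutoCloseable => c.close()
        case _ => () // not AutoCloseable on this classpath; nothing to close
      }
    }
  }
}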
@@ -15,16 +15,18 @@
*/
package com.nvidia.spark.rapids

import java.io.File
import java.io.{File, IOException}
import java.nio.file.Files
import java.sql.{Date, Timestamp}
import java.util.{Locale, TimeZone}
import java.util.{Locale, TimeZone, UUID}

import org.scalatest.{Assertion, BeforeAndAfterAll}
import org.scalatest.funsuite.AnyFunSuite
import scala.reflect.ClassTag
import scala.util.{Failure, Try}

import org.apache.hadoop.fs.FileUtil
import org.scalatest.{Assertion, BeforeAndAfterAll}
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
@@ -2157,4 +2159,16 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
false
}
}

def withTempPath[B](func: File => B): B = {
val rootTmpDir = System.getProperty("java.io.tmpdir")
val dirFile = new File(rootTmpDir, "spark-test-" + UUID.randomUUID)
Files.createDirectories(dirFile.toPath)
// delete the directory immediately so `func` receives a path that does not exist yet
if (!dirFile.delete()) throw new IOException(s"Delete $dirFile failed!")
// run the function, then remove whatever it created under the path
try func(dirFile) finally FileUtil.fullyDelete(dirFile)
}

def isCdh321: Boolean = VersionUtils.isCloudera && cmpSparkVersion(3, 2, 1) == 0

def isCdh330: Boolean = VersionUtils.isCloudera && cmpSparkVersion(3, 3, 0) == 0
}
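
A usage sketch of the new `withTempPath` helper, mirroring how checkCompressType in OrcQuerySuite combines it with `withGpuSparkSession` (the DataFrame and option values below are arbitrary for illustration):

withGpuSparkSession { spark =>
  withTempPath { dir =>
    // withTempPath created and then deleted `dir`, so the writer sees a fresh, non-existent path
    spark.range(0, 5).write.option("compression", "SNAPPY").orc(dir.getCanonicalPath)
    assert(dir.listFiles().exists(_.getName.endsWith(".snappy.orc")))
  } // the directory tree is removed here by FileUtil.fullyDelete
}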