diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/OrcQuerySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/OrcQuerySuite.scala
index e18ab8ba86e..388398c1ad6 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/OrcQuerySuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/OrcQuerySuite.scala
@@ -18,6 +18,8 @@ package com.nvidia.spark.rapids
 
 import java.io.File
 
+import scala.collection.mutable.ListBuffer
+
 import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileUtil.fullyDelete
@@ -30,6 +32,10 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.rapids.{MyDenseVector, MyDenseVectorUDT}
 import org.apache.spark.sql.types._
 
+/**
+ * This corresponds to the Spark class:
+ * org.apache.spark.sql.execution.datasources.orc.OrcQueryTest
+ */
 class OrcQuerySuite extends SparkQueryCompareTestSuite {
 
   private def getSchema: StructType = new StructType(Array(
@@ -151,4 +157,76 @@ class OrcQuerySuite extends SparkQueryCompareTestSuite {
       assert(encodingKind.toUpperCase.contains("DICTIONARY"))
     }
   }
+
+  private def getOrcFileSuffix(compression: String): String =
+    if (Seq("NONE", "UNCOMPRESSED").contains(compression)) {
+      ".orc"
+    } else {
+      s".${compression.toLowerCase()}.orc"
+    }
+
+  def checkCompressType(compression: Option[String], orcCompress: Option[String]): Unit = {
+    withGpuSparkSession { spark =>
+      withTempPath { file =>
+        var writer = spark.range(0, 10).write
+        writer = compression.map(t => writer.option("compression", t)).getOrElse(writer)
+        writer = orcCompress.map(t => writer.option("orc.compress", t)).getOrElse(writer)
+        // write the ORC file on the GPU
+        writer.orc(file.getCanonicalPath)
+
+        // expectedType: prefer "compression", falling back to "orc.compress"
+        var expectedType = compression.getOrElse(orcCompress.get)
+        // ORC uses NONE for UNCOMPRESSED
+        if (expectedType == "UNCOMPRESSED") expectedType = "NONE"
+        val maybeOrcFile = file.listFiles()
+          .find(_.getName.endsWith(getOrcFileSuffix(expectedType)))
+        assert(maybeOrcFile.isDefined)
+
+        // check the compression type using the ORC reader
+        val orcFilePath = new Path(maybeOrcFile.get.getAbsolutePath)
+        val conf = OrcFile.readerOptions(new Configuration())
+
+        // the reader is not an AutoCloseable for Spark CDH, so use `withResourceIfAllowed`
+        // 321cdh uses an older ORC: orc-core-1.5.1.7.1.7.1000-141.jar
+        // 330cdh uses an older ORC: orc-core-1.5.1.7.1.8.0-801.jar
+        withResourceIfAllowed(OrcFile.createReader(orcFilePath, conf)) { reader =>
+          // verify the file footer records the expected compression kind
+          assert(expectedType === reader.getCompressionKind.name)
+        }
+      }
+    }
+  }
+
+  private val supportedWriteCompressTypes = {
+    // GPU ORC writing does not support ZLIB or LZ4; refer to GpuOrcFileFormat
+    val supportedWriteCompressType = ListBuffer("UNCOMPRESSED", "NONE", "ZSTD", "SNAPPY")
+    // Cdh321 and Cdh330 do not support ZSTD; refer to the CDH class:
+    // org.apache.spark.sql.execution.datasources.orc.OrcOptions
+    // Spark 3.1.x does not support lz4 or zstd
+    if (isCdh321 || isCdh330 || !VersionUtils.isSpark320OrLater) {
+      supportedWriteCompressType -= "ZSTD"
+    }
+    supportedWriteCompressType
+  }
+
+  test("SPARK-16610: Respect orc.compress (i.e., OrcConf.COMPRESS) when compression is unset") {
+    // Respect `orc.compress` (i.e., OrcConf.COMPRESS).
+    supportedWriteCompressTypes.foreach { orcCompress =>
+      checkCompressType(None, Some(orcCompress))
+    }
+
+    // make sliding pairs, e.g.: [("UNCOMPRESSED", "NONE"), ("NONE", "ZSTD"), ("ZSTD", "SNAPPY")]
+    val pairs = supportedWriteCompressTypes.sliding(2).toList.map(pair => (pair.head, pair.last))
+
+    // "compression" overrides "orc.compress"
+    pairs.foreach { case (compression, orcCompress) =>
+      checkCompressType(Some(compression), Some(orcCompress))
+    }
+  }
+
+  test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") {
+    supportedWriteCompressTypes.foreach { compression =>
+      checkCompressType(Some(compression), None)
+    }
+  }
 }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index ecdea17e03c..c283bd7afd1 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -2167,4 +2167,8 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
     if (!dirFile.delete()) throw new IOException(s"Delete $dirFile failed!")
     try func(dirFile) finally FileUtil.fullyDelete(dirFile)
   }
+
+  def isCdh321: Boolean = VersionUtils.isCloudera && cmpSparkVersion(3, 2, 1) == 0
+
+  def isCdh330: Boolean = VersionUtils.isCloudera && cmpSparkVersion(3, 3, 0) == 0
 }