diff --git a/connector/avro/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java b/connector/avro/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java
new file mode 100644
index 0000000000000..5cfcfffd07a28
--- /dev/null
+++ b/connector/avro/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro;
+
+import java.util.Arrays;
+import java.util.Locale;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.apache.avro.file.DataFileConstants;
+
+/**
+ * Maps the Avro compression codecs supported by Spark to the codec names Avro expects.
+ */
+public enum AvroCompressionCodec {
+  UNCOMPRESSED(DataFileConstants.NULL_CODEC),
+  DEFLATE(DataFileConstants.DEFLATE_CODEC),
+  SNAPPY(DataFileConstants.SNAPPY_CODEC),
+  BZIP2(DataFileConstants.BZIP2_CODEC),
+  XZ(DataFileConstants.XZ_CODEC),
+  ZSTANDARD(DataFileConstants.ZSTANDARD_CODEC);
+
+  private final String codecName;
+
+  AvroCompressionCodec(String codecName) {
+    this.codecName = codecName;
+  }
+
+  public String getCodecName() {
+    return this.codecName;
+  }
+
+  private static final Map<String, String> codecNameMap =
+    Arrays.stream(AvroCompressionCodec.values()).collect(
+      Collectors.toMap(codec -> codec.name(), codec -> codec.name().toLowerCase(Locale.ROOT)));
+
+  public String lowerCaseName() {
+    return codecNameMap.get(this.name());
+  }
+
+  public static AvroCompressionCodec fromString(String s) {
+    return AvroCompressionCodec.valueOf(s.toUpperCase(Locale.ROOT));
+  }
+}
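
For reference, a minimal sketch of the round trip this enum provides (the enum, its methods, and the codec constants come from the diff above; the wrapper object and printed values are illustrative only). Note that the Spark-facing name and the Avro codec name coincide for most codecs but diverge for UNCOMPRESSED, whose Avro codec is DataFileConstants.NULL_CODEC ("null"):

import org.apache.spark.sql.avro.AvroCompressionCodec

object AvroCodecNamesSketch {
  def main(args: Array[String]): Unit = {
    // fromString upper-cases its input, so "Snappy", "snappy" and "SNAPPY"
    // all resolve to the same constant.
    val snappy = AvroCompressionCodec.fromString("Snappy")
    println(snappy.lowerCaseName()) // "snappy" -- the value users pass to the SQL option
    println(snappy.getCodecName())  // "snappy" -- the name Avro's writer expects

    // The two names differ only for UNCOMPRESSED.
    val none = AvroCompressionCodec.UNCOMPRESSED
    println(none.lowerCaseName())   // "uncompressed"
    println(none.getCodecName())    // "null" (DataFileConstants.NULL_CODEC)
  }
}
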
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 55dea6ed959f8..6a1655a91c918 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -23,7 +23,6 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.avro.Schema
 import org.apache.avro.file.{DataFileReader, FileReader}
-import org.apache.avro.file.DataFileConstants.{BZIP2_CODEC, DEFLATE_CODEC, SNAPPY_CODEC, XZ_CODEC, ZSTANDARD_CODEC}
 import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
 import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
 import org.apache.avro.mapreduce.AvroJob
@@ -34,6 +33,7 @@ import org.apache.hadoop.mapreduce.Job
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.avro.AvroCompressionCodec._
 import org.apache.spark.sql.avro.AvroOptions.IGNORE_EXTENSION
 import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -100,18 +100,19 @@ private[sql] object AvroUtils extends Logging {
 
     AvroJob.setOutputKeySchema(job, outputAvroSchema)
 
-    if (parsedOptions.compression == "uncompressed") {
+    if (parsedOptions.compression == UNCOMPRESSED.lowerCaseName()) {
       job.getConfiguration.setBoolean("mapred.output.compress", false)
     } else {
       job.getConfiguration.setBoolean("mapred.output.compress", true)
       logInfo(s"Compressing Avro output using the ${parsedOptions.compression} codec")
-      val codec = parsedOptions.compression match {
-        case DEFLATE_CODEC =>
+      val codec = AvroCompressionCodec.fromString(parsedOptions.compression) match {
+        case DEFLATE =>
           val deflateLevel = sqlConf.avroDeflateLevel
-          logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.")
+          logInfo(s"Avro compression level $deflateLevel will be used for " +
+            s"${DEFLATE.getCodecName()} codec.")
           job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel)
-          DEFLATE_CODEC
-        case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC | ZSTANDARD_CODEC) => codec
+          DEFLATE.getCodecName()
+        case codec @ (SNAPPY | BZIP2 | XZ | ZSTANDARD) => codec.getCodecName()
         case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown")
       }
       job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
index 9a4bcc623b0d4..4e4942e1b2e26 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
@@ -24,6 +24,6 @@ class AvroCodecSuite extends FileSourceCodecSuite {
 
   override def format: String = "avro"
   override val codecConfigName: String = SQLConf.AVRO_COMPRESSION_CODEC.key
-  override protected def availableCodecs = Seq("uncompressed", "deflate", "snappy",
-    "bzip2", "xz", "zstandard")
+  override protected def availableCodecs =
+    AvroCompressionCodec.values().map(_.lowerCaseName()).iterator.to(Seq)
 }
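
To isolate the control flow that the AvroUtils hunk rewrites, here is a hedged sketch with the Hadoop Job wiring stubbed out. resolveAvroCodec is a hypothetical helper, not Spark code: compression stands in for parsedOptions.compression, deflateLevel for sqlConf.avroDeflateLevel, and the surrounding if/else is folded into the match. Because fromString delegates to Enum.valueOf, an unrecognized name throws IllegalArgumentException before the match is ever reached:

import org.apache.spark.sql.avro.AvroCompressionCodec
import org.apache.spark.sql.avro.AvroCompressionCodec._

object AvroCodecDispatchSketch {
  // Returns None when output compression should be disabled, otherwise the
  // Avro codec name to store under AvroJob.CONF_OUTPUT_CODEC.
  def resolveAvroCodec(compression: String, deflateLevel: Int): Option[String] =
    AvroCompressionCodec.fromString(compression) match {
      case UNCOMPRESSED =>
        None // the real code sets mapred.output.compress = false instead
      case DEFLATE =>
        // the real code also sets AvroOutputFormat.DEFLATE_LEVEL_KEY to deflateLevel
        Some(DEFLATE.getCodecName())
      case codec @ (SNAPPY | BZIP2 | XZ | ZSTANDARD) =>
        Some(codec.getCodecName())
    }
}

Under these assumptions, resolveAvroCodec("deflate", 9) yields Some("deflate") and resolveAvroCodec("uncompressed", 9) yields None, mirroring the two branches of the if in the hunk above.
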
"snappy") + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, + AvroCompressionCodec.SNAPPY.lowerCaseName()) df.write.format("avro").save(snappyDir) - spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, "zstandard") + spark.conf.set(SQLConf.AVRO_COMPRESSION_CODEC.key, + AvroCompressionCodec.ZSTANDARD.lowerCaseName()) df.write.format("avro").save(zstandardDir) val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir))