diff --git a/velox/connectors/hive/HiveDataSink.cpp b/velox/connectors/hive/HiveDataSink.cpp index aead9d3381e4..a936580e4860 100644 --- a/velox/connectors/hive/HiveDataSink.cpp +++ b/velox/connectors/hive/HiveDataSink.cpp @@ -677,57 +677,106 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { ioStats_.emplace_back(std::make_shared()); setMemoryReclaimers(writerInfo_.back().get(), ioStats_.back().get()); - dwio::common::WriterOptions options; + // Take the one provided by the user as a starting point, or allocate a new + // one. + const auto& writerOptions = insertTableHandle_->writerOptions(); + std::shared_ptr options = writerOptions + ? writerOptions + : std::make_shared(); + const auto* connectorSessionProperties = connectorQueryCtx_->sessionProperties(); - options.schema = getNonPartitionTypes(dataChannels_, inputType_); - - options.memoryPool = writerInfo_.back()->writerPool.get(); - options.compressionKind = insertTableHandle_->compressionKind(); - if (canReclaim()) { - options.spillConfig = spillConfig_; - } - options.nonReclaimableSection = - writerInfo_.back()->nonReclaimableSectionHolder.get(); - options.maxStripeSize = std::optional( - hiveConfig_->orcWriterMaxStripeSize(connectorSessionProperties)); - options.maxDictionaryMemory = std::optional( - hiveConfig_->orcWriterMaxDictionaryMemory(connectorSessionProperties)); - options.orcWriterIntegerDictionaryEncodingEnabled = - hiveConfig_->isOrcWriterIntegerDictionaryEncodingEnabled( - connectorSessionProperties); - options.orcWriterStringDictionaryEncodingEnabled = - hiveConfig_->isOrcWriterStringDictionaryEncodingEnabled( - connectorSessionProperties); - options.parquetWriteTimestampUnit = - hiveConfig_->parquetWriteTimestampUnit(connectorSessionProperties); - options.orcMinCompressionSize = std::optional( - hiveConfig_->orcWriterMinCompressionSize(connectorSessionProperties)); - options.orcLinearStripeSizeHeuristics = - std::optional(hiveConfig_->orcWriterLinearStripeSizeHeuristics( - 
connectorSessionProperties)); - options.serdeParameters = std::map( - insertTableHandle_->serdeParameters().begin(), - insertTableHandle_->serdeParameters().end()); + + if (!options->schema) { + options->schema = getNonPartitionTypes(dataChannels_, inputType_); + } + + if (!options->memoryPool) { + options->memoryPool = writerInfo_.back()->writerPool.get(); + } + + if (!options->compressionKind) { + options->compressionKind = insertTableHandle_->compressionKind(); + } + + if (!options->spillConfig && canReclaim()) { + options->spillConfig = spillConfig_; + } + + if (!options->nonReclaimableSection) { + options->nonReclaimableSection = + writerInfo_.back()->nonReclaimableSectionHolder.get(); + } + + if (!options->maxStripeSize) { + options->maxStripeSize = std::optional( + hiveConfig_->orcWriterMaxStripeSize(connectorSessionProperties)); + } + + if (!options->maxDictionaryMemory) { + options->maxDictionaryMemory = std::optional( + hiveConfig_->orcWriterMaxDictionaryMemory(connectorSessionProperties)); + } + + if (!options->orcWriterIntegerDictionaryEncodingEnabled) { + options->orcWriterIntegerDictionaryEncodingEnabled = + hiveConfig_->isOrcWriterIntegerDictionaryEncodingEnabled( + connectorSessionProperties); + } + + if (!options->orcWriterStringDictionaryEncodingEnabled) { + options->orcWriterStringDictionaryEncodingEnabled = + hiveConfig_->isOrcWriterStringDictionaryEncodingEnabled( + connectorSessionProperties); + } + + if (!options->parquetWriteTimestampUnit) { + options->parquetWriteTimestampUnit = + hiveConfig_->parquetWriteTimestampUnit(connectorSessionProperties); + } + + if (!options->orcMinCompressionSize) { + options->orcMinCompressionSize = std::optional( + hiveConfig_->orcWriterMinCompressionSize(connectorSessionProperties)); + } + + if (!options->orcLinearStripeSizeHeuristics) { + options->orcLinearStripeSizeHeuristics = + std::optional(hiveConfig_->orcWriterLinearStripeSizeHeuristics( + connectorSessionProperties)); + } + + if 
(options->serdeParameters.empty()) { + options->serdeParameters = std::map( + insertTableHandle_->serdeParameters().begin(), + insertTableHandle_->serdeParameters().end()); + } auto compressionLevel = hiveConfig_->orcWriterCompressionLevel(connectorSessionProperties); - options.zlibCompressionLevel = - compressionLevel.value_or(kDefaultZlibCompressionLevel); - options.zstdCompressionLevel = - compressionLevel.value_or(kDefaultZstdCompressionLevel); + + if (!options->zlibCompressionLevel) { + options->zlibCompressionLevel = + compressionLevel.value_or(kDefaultZlibCompressionLevel); + } + if (!options->zstdCompressionLevel) { + options->zstdCompressionLevel = + compressionLevel.value_or(kDefaultZstdCompressionLevel); + } // Prevents the memory allocation during the writer creation. WRITER_NON_RECLAIMABLE_SECTION_GUARD(writerInfo_.size() - 1); auto writer = writerFactory_->createWriter( dwio::common::FileSink::create( writePath, - {.bufferWrite = false, - .connectorProperties = hiveConfig_->config(), - .fileCreateConfig = hiveConfig_->writeFileCreateConfig(), - .pool = writerInfo_.back()->sinkPool.get(), - .metricLogger = dwio::common::MetricsLog::voidLog(), - .stats = ioStats_.back().get()}), + { + .bufferWrite = false, + .connectorProperties = hiveConfig_->config(), + .fileCreateConfig = hiveConfig_->writeFileCreateConfig(), + .pool = writerInfo_.back()->sinkPool.get(), + .metricLogger = dwio::common::MetricsLog::voidLog(), + .stats = ioStats_.back().get(), + }), options); writer = maybeCreateBucketSortWriter(std::move(writer)); writers_.emplace_back(std::move(writer)); diff --git a/velox/connectors/hive/HiveDataSink.h b/velox/connectors/hive/HiveDataSink.h index d3cea0f886c6..322e44cb6786 100644 --- a/velox/connectors/hive/HiveDataSink.h +++ b/velox/connectors/hive/HiveDataSink.h @@ -186,9 +186,7 @@ FOLLY_ALWAYS_INLINE std::ostream& operator<<( class HiveInsertTableHandle; using HiveInsertTableHandlePtr = std::shared_ptr; -/** - * Represents a request for Hive 
write. - */ +/// Represents a request for Hive write. class HiveInsertTableHandle : public ConnectorInsertTableHandle { public: HiveInsertTableHandle( @@ -198,13 +196,16 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { dwio::common::FileFormat::DWRF, std::shared_ptr bucketProperty = nullptr, std::optional compressionKind = {}, - const std::unordered_map& serdeParameters = {}) + const std::unordered_map& serdeParameters = {}, + const std::shared_ptr& writerOptions = + nullptr) : inputColumns_(std::move(inputColumns)), locationHandle_(std::move(locationHandle)), tableStorageFormat_(tableStorageFormat), bucketProperty_(std::move(bucketProperty)), compressionKind_(compressionKind), - serdeParameters_(serdeParameters) { + serdeParameters_(serdeParameters), + writerOptions_(writerOptions) { if (compressionKind.has_value()) { VELOX_CHECK( compressionKind.value() != common::CompressionKind_MAX, @@ -235,6 +236,10 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { return serdeParameters_; } + const std::shared_ptr& writerOptions() const { + return writerOptions_; + } + bool supportsMultiThreading() const override { return true; } @@ -262,6 +267,7 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { const std::shared_ptr bucketProperty_; const std::optional compressionKind_; const std::unordered_map serdeParameters_; + const std::shared_ptr writerOptions_; }; /// Parameters for Hive writers. 
diff --git a/velox/dwio/common/Options.h b/velox/dwio/common/Options.h index 199d2cb1afc8..26d6e7259eac 100644 --- a/velox/dwio/common/Options.h +++ b/velox/dwio/common/Options.h @@ -608,6 +608,8 @@ struct WriterOptions { std::optional parquetWriteTimestampUnit; std::optional zlibCompressionLevel; std::optional zstdCompressionLevel; + + virtual ~WriterOptions() = default; }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/WriterFactory.h b/velox/dwio/common/WriterFactory.h index cb59c47bfc7f..cebb269d94d4 100644 --- a/velox/dwio/common/WriterFactory.h +++ b/velox/dwio/common/WriterFactory.h @@ -50,7 +50,7 @@ class WriterFactory { /// @return writer object virtual std::unique_ptr createWriter( std::unique_ptr sink, - const dwio::common::WriterOptions& options) = 0; + std::shared_ptr options) = 0; private: const FileFormat format_; diff --git a/velox/dwio/dwrf/writer/Writer.cpp b/velox/dwio/dwrf/writer/Writer.cpp index e6e4257d77bd..001a6b0e1fbb 100644 --- a/velox/dwio/dwrf/writer/Writer.cpp +++ b/velox/dwio/dwrf/writer/Writer.cpp @@ -855,8 +855,8 @@ dwrf::WriterOptions getDwrfOptions(const dwio::common::WriterOptions& options) { std::unique_ptr DwrfWriterFactory::createWriter( std::unique_ptr sink, - const dwio::common::WriterOptions& options) { - auto dwrfOptions = getDwrfOptions(options); + std::shared_ptr options) { + auto dwrfOptions = getDwrfOptions(*options); return std::make_unique(std::move(sink), dwrfOptions); } diff --git a/velox/dwio/dwrf/writer/Writer.h b/velox/dwio/dwrf/writer/Writer.h index f3f67147f82f..ee480e9576db 100644 --- a/velox/dwio/dwrf/writer/Writer.h +++ b/velox/dwio/dwrf/writer/Writer.h @@ -219,7 +219,7 @@ class DwrfWriterFactory : public dwio::common::WriterFactory { std::unique_ptr createWriter( std::unique_ptr sink, - const dwio::common::WriterOptions& options) override; + std::shared_ptr options) override; }; } // namespace facebook::velox::dwrf diff --git a/velox/dwio/parquet/writer/Writer.cpp 
b/velox/dwio/parquet/writer/Writer.cpp index a0ba3a6cc6a5..3d4cebac4a98 100644 --- a/velox/dwio/parquet/writer/Writer.cpp +++ b/velox/dwio/parquet/writer/Writer.cpp @@ -416,10 +416,10 @@ void Writer::setMemoryReclaimers() { std::unique_ptr ParquetWriterFactory::createWriter( std::unique_ptr sink, - const dwio::common::WriterOptions& options) { - auto parquetOptions = getParquetOptions(options); + std::shared_ptr options) { + auto parquetOptions = getParquetOptions(*options); return std::make_unique( - std::move(sink), parquetOptions, asRowType(options.schema)); + std::move(sink), parquetOptions, asRowType(options->schema)); } } // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/writer/Writer.h b/velox/dwio/parquet/writer/Writer.h index 7f1886708a2b..99d0ea3bd0a9 100644 --- a/velox/dwio/parquet/writer/Writer.h +++ b/velox/dwio/parquet/writer/Writer.h @@ -171,7 +171,7 @@ class ParquetWriterFactory : public dwio::common::WriterFactory { std::unique_ptr createWriter( std::unique_ptr sink, - const dwio::common::WriterOptions& options) override; + std::shared_ptr options) override; }; } // namespace facebook::velox::parquet diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index ef687de8ba89..99791ff0cb63 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -34,7 +34,6 @@ using namespace facebook::velox; namespace facebook::velox::exec::test { - namespace { void writeToFile( @@ -43,9 +42,9 @@ void writeToFile( memory::MemoryPool* pool) { VELOX_CHECK_GT(data.size(), 0); - dwio::common::WriterOptions options; - options.schema = data[0]->type(); - options.memoryPool = pool; + auto options = std::make_shared(); + options->schema = data[0]->type(); + options->memoryPool = pool; auto writeFile = std::make_unique(path, true, false); auto sink = @@ -159,6 +158,7 @@ class ServerResponse { private: folly::dynamic response_; }; + } // namespace 
PrestoQueryRunner::PrestoQueryRunner( diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index 9a538b76313c..c890b7838ed2 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -23,6 +23,7 @@ #include "velox/vector/ComplexVector.h" namespace facebook::velox::exec::test { + template T extractSingleValue(const std::vector& data) { auto simpleVector = data[0]->childAt(0)->as>(); diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index a1db6fb23544..0232124f778a 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -383,7 +383,8 @@ PlanBuilder& PlanBuilder::tableWrite( const dwio::common::FileFormat fileFormat, const std::vector& aggregates, const std::string& connectorId, - const std::unordered_map& serdeParameters) { + const std::unordered_map& serdeParameters, + const std::shared_ptr& writerOptions) { VELOX_CHECK_NOT_NULL(planNode_, "TableWrite cannot be the source node"); auto rowType = planNode_->outputType(); @@ -418,7 +419,8 @@ PlanBuilder& PlanBuilder::tableWrite( fileFormat, bucketProperty, common::CompressionKind_NONE, - serdeParameters); + serdeParameters, + writerOptions); auto insertHandle = std::make_shared(connectorId, hiveHandle); diff --git a/velox/exec/tests/utils/PlanBuilder.h b/velox/exec/tests/utils/PlanBuilder.h index 471940f6766e..2cc6ca1af766 100644 --- a/velox/exec/tests/utils/PlanBuilder.h +++ b/velox/exec/tests/utils/PlanBuilder.h @@ -435,6 +435,9 @@ class PlanBuilder { /// @param fileFormat File format to use for the written data. /// @param aggregates Aggregations for column statistics collection during /// write. + /// @param connectorId Name used to register the connector. + /// @param serdeParameters Additional parameters passed to the writer. + /// @param writerOptions Option objects passed to the writer. 
PlanBuilder& tableWrite( const std::string& outputDirectoryPath, const std::vector& partitionBy, @@ -446,7 +449,9 @@ class PlanBuilder { dwio::common::FileFormat::DWRF, const std::vector& aggregates = {}, const std::string& connectorId = "test-hive", - const std::unordered_map& serdeParameters = {}); + const std::unordered_map& serdeParameters = {}, + const std::shared_ptr& writerOptions = + nullptr); /// Add a TableWriteMergeNode. PlanBuilder& tableWriteMerge(