diff --git a/README.md b/README.md index 90f57ba7d1a..cf67d3ad304 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ The simplest way to use MyScaleDB is to create an instance on MyScale Cloud serv To quickly get a MyScaleDB instance up and running, simply pull and run the latest Docker image: ```bash -docker run --name myscaledb --net=host myscale/myscaledb:1.7.1 +docker run --name myscaledb --net=host myscale/myscaledb:1.8.0 ``` >Note: Myscale's default configuration only allows localhost ip access. For the docker run startup method, you need to specify `--net=host` to access services deployed in docker mode on the current node. @@ -114,7 +114,7 @@ version: '3.7' services: myscaledb: - image: myscale/myscaledb:1.7.1 + image: myscale/myscaledb:1.8.0 tty: true ports: - '8123:8123' diff --git a/cmake/autogenerated_myscale_versions.txt b/cmake/autogenerated_myscale_versions.txt index 1ed74ebbb86..014966f94a6 100644 --- a/cmake/autogenerated_myscale_versions.txt +++ b/cmake/autogenerated_myscale_versions.txt @@ -3,9 +3,9 @@ # NOTE: has nothing common with DBMS_TCP_PROTOCOL_VERSION, # only DBMS_TCP_PROTOCOL_VERSION should be incremented on protocol changes. SET(MYSCALE_VERSION_MAJOR 1) -SET(MYSCALE_VERSION_MINOR 7) -SET(MYSCALE_VERSION_PATCH 1) -SET(MYSCALE_VERSION_DESCRIBE myscale-v1.7.1) -SET(MYSCALE_VERSION_STRING 1.7.1) +SET(MYSCALE_VERSION_MINOR 8) +SET(MYSCALE_VERSION_PATCH 0) +SET(MYSCALE_VERSION_DESCRIBE myscale-v1.8.0) +SET(MYSCALE_VERSION_STRING 1.8.0) # end of autochange diff --git a/rust/supercrate/CMakeLists.txt b/rust/supercrate/CMakeLists.txt index a2e9bfe3823..146dc41b083 100644 --- a/rust/supercrate/CMakeLists.txt +++ b/rust/supercrate/CMakeLists.txt @@ -69,6 +69,10 @@ target_include_directories(supercrate_cxxbridge ${cxx_include} ) +#-Wno-dollar-in-identifier-extension -Wno-unused-macros +target_compile_options(supercrate_cxxbridge PUBLIC -Wno-dollar-in-identifier-extension) +target_compile_options(supercrate_cxxbridge PUBLIC -Wno-unused-macros) + # Create total target with alias with given namespace add_library(supercrate-total INTERFACE) target_link_libraries(supercrate-total diff --git a/rust/supercrate/libs/tantivy_search b/rust/supercrate/libs/tantivy_search index ac72130d765..6be234890b5 160000 --- a/rust/supercrate/libs/tantivy_search +++ b/rust/supercrate/libs/tantivy_search @@ -1 +1 @@ -Subproject commit ac72130d7655c942faad01580cd67221e749fa62 +Subproject commit 6be234890b53caed2044518abb69737c48ac99fc diff --git a/src/Core/Settings.h b/src/Core/Settings.h index bd773d22fe0..3516ca8c727 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -922,6 +922,8 @@ class IColumn; M(Bool, optimize_prefilter_in_search, true, "Enable prewhere optimization for vector or text search if some partition columns in prewhere condition.", 0) \ M(UInt64, max_search_result_window, 10000, "The maximum value of n + m in limit clause for pagination in vector/text/hybrid search", 0) \ M(Bool, dfs_query_then_fetch, false, "Enable Distributed Frequency Search (DFS) query to gather global statistical info for accurate BM25 calculation.", 0) \ + M(UInt64, distances_top_k_multiply_factor, 3, "Multiply k in limit by this factor for the top_k in multiple distance functions", 0) \ + M(UInt64, parallel_reading_prefilter_option, 1, "Control parallel reading prefilter options for vector/text/hybrid search in SELECT queries with where clause. 0 - disable, 1 - adaptive enable depending on mark ranges and row count. 
2 - always enable.", 0) \ // End of COMMON_SETTINGS // Please add settings related to formats into the FORMAT_FACTORY_SETTINGS and move obsolete settings to OBSOLETE_SETTINGS. diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 1ca28743c3f..051399bbbb6 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -4151,19 +4151,19 @@ ReadSettings Context::getReadSettings() const return res; } -std::optional Context::getVecScanDescription() const +MutableVSDescriptionsPtr Context::getVecScanDescriptions() const { - return vector_scan_description; + return right_vector_scan_descs; } -void Context::setVecScanDescription(VSDescription & vec_scan_desc) const +void Context::setVecScanDescriptions(MutableVSDescriptionsPtr vec_scan_descs) const { - vector_scan_description = vec_scan_desc; + right_vector_scan_descs = vec_scan_descs; } -void Context::resetVecScanDescription() const +void Context::resetVecScanDescriptions() const { - vector_scan_description.reset(); + right_vector_scan_descs.reset(); } TextSearchInfoPtr Context::getTextSearchInfo() const diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 822d4ee566c..7515647ad2f 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -408,7 +408,7 @@ class Context: public std::enable_shared_from_this /// TODO: will be enhanced similar as scalars. /// Used when vector scan func exists in right joined table - mutable std::optional vector_scan_description; + mutable MutableVSDescriptionsPtr right_vector_scan_descs; mutable TextSearchInfoPtr right_text_search_info; mutable HybridSearchInfoPtr right_hybrid_search_info; @@ -1160,9 +1160,9 @@ class Context: public std::enable_shared_from_this ParallelReplicasMode getParallelReplicasMode() const; /// Used for vector scan functions - std::optional getVecScanDescription() const; - void setVecScanDescription(VSDescription & vec_scan_desc) const; - void resetVecScanDescription() const; + MutableVSDescriptionsPtr getVecScanDescriptions() const; + void setVecScanDescriptions(MutableVSDescriptionsPtr vec_scan_descs) const; + void resetVecScanDescriptions() const; /// Used for text search functions TextSearchInfoPtr getTextSearchInfo() const; diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index f68ba74b201..5a6933ef7b5 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -89,6 +89,7 @@ #include #include #include +#include #include namespace DB @@ -143,10 +144,15 @@ inline void checkTantivyIndex([[maybe_unused]]const StorageSnapshotPtr & storage for (const auto & index_desc : metadata_snapshot->getSecondaryIndices()) { /// Find tantivy inverted index on the search column - if (index_desc.type == TANTIVY_INDEX_NAME && index_desc.column_names.size() == 1 && index_desc.column_names[0] == text_column_name) + if (index_desc.type == TANTIVY_INDEX_NAME) { - find_tantivy_index = true; - break; + auto & column_names = index_desc.column_names; + /// Support search on a column in a multi-columns index + if (std::find(column_names.begin(), column_names.end(), text_column_name) != column_names.end()) + { + find_tantivy_index = true; + break; + } } } } @@ -561,20 +567,22 @@ void ExpressionAnalyzer::analyzeVectorScan(ActionsDAGPtr & temp_actions) { if (syntax->search_func_type == HybridSearchFuncType::VECTOR_SCAN && !syntax->hybrid_search_funcs.empty()) has_vector_scan = makeVectorScanDescriptions(temp_actions); - else if (auto vec_scan_desc = 
getContext()->getVecScanDescription())
+    else if (auto vec_scan_descs = getContext()->getVecScanDescriptions())
     {
         if (syntax->storage_snapshot)
         {
             LOG_DEBUG(getLogger(), "[analyzeVectorScan] Get vector scan function from right table");
             /// vector search column exists in right joined table
-            vector_scan_descriptions.emplace_back(*vec_scan_desc);
+            vector_scan_descriptions = *vec_scan_descs;
             has_vector_scan = true;
         }
     }
     /// Fill in dim and recognize VectorSearchType from metadata
     if (has_vector_scan)
     {
-        getAndCheckVectorScanInfoFromMetadata(syntax->storage_snapshot, vector_scan_descriptions[0], getContext());
+        /// Support multiple distance functions
+        for (auto & vector_scan_desc : vector_scan_descriptions)
+            getAndCheckVectorScanInfoFromMetadata(syntax->storage_snapshot, vector_scan_desc, getContext());
     }
 }
@@ -620,7 +628,7 @@ void ExpressionAnalyzer::analyzeHybridSearch(ActionsDAGPtr & temp_actions)
     if (!syntax->is_remote_storage && hybrid_search_info->text_search_info)
         checkTantivyIndex(syntax->storage_snapshot, hybrid_search_info->text_search_info->text_column_name);
-    /// Get vector search type and dim from metadata, check paramaters in vector scan and add to vector_paramters
+    /// Get vector search type and dim from metadata, check parameters in vector scan and add to vector_parameters
     VSDescription & vec_scan_desc = const_cast(hybrid_search_info->vector_scan_info->vector_scan_descs[0]);
@@ -877,7 +885,9 @@ VSDescription ExpressionAnalyzer::commonMakeVectorScanDescription(
     const String & function_col_name,
     ASTPtr query_column,
     ASTPtr query_vector,
-    int topk)
+    int topk,
+    String vector_scan_metric_type,
+    Search::DataType vector_search_type)
 {
     VSDescription vector_scan_desc;
     vector_scan_desc.column_name = function_col_name;
@@ -941,13 +951,13 @@ VSDescription ExpressionAnalyzer::commonMakeVectorScanDescription(
         vector_scan_desc.query_column_name);
     /// vector search type from syntax result
-    vector_scan_desc.vector_search_type = syntax->vector_search_type;
+    vector_scan_desc.vector_search_type = vector_search_type;
     /// top_k is get from limit N
     vector_scan_desc.topk = topk;
     /// Pass the correct direction to vector_scan_desc according to metric_type
-    vector_scan_desc.direction = Poco::toUpper(syntax->vector_scan_metric_type) == "IP" ? -1 : 1;
+    vector_scan_desc.direction = Poco::toUpper(vector_scan_metric_type) == "IP" ? -1 : 1;
     return vector_scan_desc;
 }
@@ -955,31 +965,38 @@ VSDescription ExpressionAnalyzer::commonMakeVectorScanDescription(
 /// create vector scan descriptions, mainly record the column name and parameters
 bool ExpressionAnalyzer::makeVectorScanDescriptions(ActionsDAGPtr & actions)
 {
-    for (const ASTFunction * node : hybrid_search_funcs())
+    for (size_t i = 0; i < hybrid_search_funcs().size(); ++i)
     {
-        if (node->arguments)
-            getRootActionsNoMakeSet(node->arguments, actions);
+        const ASTFunction * node = hybrid_search_funcs()[i];
+        const ASTs & arguments = node->arguments ? node->arguments->children : ASTs(); // arguments 0 indicates the vector name; arguments 1 indicates the specific vector content.
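        /// (Editorial sketch, not part of the patch.) With multiple distance functions now allowed,
        /// a single query may contain e.g. distance(dense_vec, [0.1, 0.2]) AS d1 and
        /// distance(other_vec, [0.3, 0.4]) AS d2 (column names and vectors are illustrative only);
        /// the loop here builds one VSDescription per call, each carrying its own metric type
        /// (syntax->vector_scan_metric_types[i]) and search data type (syntax->vector_search_types[i]).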
if (arguments.size() != 2) { - throw Exception(ErrorCodes::BAD_ARGUMENTS, - "wrong argument number in distance function"); + throw Exception(ErrorCodes::BAD_ARGUMENTS, "wrong argument number in distance function"); } - auto vector_scan_desc = commonMakeVectorScanDescription(actions, node->getColumnName(), arguments[0], arguments[1], static_cast(syntax->limit_length)); + getRootActionsNoMakeSet(node->arguments, actions); + + auto vector_scan_desc = commonMakeVectorScanDescription(actions, node->getColumnName(), arguments[0], arguments[1], + static_cast(syntax->limit_length), syntax->vector_scan_metric_types[i], syntax->vector_search_types[i]); /// Save parameters, parse and check parameters will be done in analyzeVectorScan() vector_scan_desc.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters, "", getContext()) : Array(); - LOG_DEBUG(getLogger(), "[makeVectorScanDescriptions] create vector scan function: {}", node->name); + LOG_DEBUG(getLogger(), "[makeVectorScanDescriptions] create vector scan function: {}, column name:{}", node->name, node->getColumnName()); - if (syntax->hybrid_search_from_right_table) - { - analyzedJoin().setVecScanDescription(vector_scan_desc); - } - else - vector_scan_descriptions.push_back(vector_scan_desc); + vector_scan_descriptions.push_back(vector_scan_desc); + } + + /// Support multiple distance functions + /// If vector columns are from right table, save the vector scan descriptions to analyzedJoin(). + if (syntax->hybrid_search_from_right_table) + { + auto vector_scan_descs_ptr = std::make_shared(vector_scan_descriptions); + analyzedJoin().setVecScanDescriptions(vector_scan_descs_ptr); + + vector_scan_descriptions.clear(); } return !vector_scan_descriptions.empty(); @@ -1209,7 +1226,8 @@ bool ExpressionAnalyzer::makeHybridSearchInfo(ActionsDAGPtr & actions) /// make VSDescription for HybridSearchInfo { getRootActionsNoMakeSet(arguments[2], actions); - auto vector_scan_desc = commonMakeVectorScanDescription(actions, "distance_func", arguments[0], arguments[2], num_candidates); + auto vector_scan_desc = commonMakeVectorScanDescription(actions, "distance_func", arguments[0], arguments[2], + num_candidates, syntax->vector_scan_metric_types[0], syntax->vector_search_types[0]); /// Save vector_scan_parameter to vector_scan_desc's parameters if (!vector_scan_parameter.empty()) @@ -1738,10 +1756,10 @@ static std::unique_ptr buildJoinedPlan( { /// Add vector scan description to Context for subquery of joined table bool has_vector_scan = false; - if (auto vec_scan_desc = analyzed_join.getVecScanDescription()) + if (auto vec_scan_descs = analyzed_join.getVecScanDescriptions()) { has_vector_scan = true; - context->setVecScanDescription(*vec_scan_desc); + context->setVecScanDescriptions(vec_scan_descs); } /// Add text search info to Context for subquery of joined table @@ -1806,7 +1824,7 @@ static std::unique_ptr buildJoinedPlan( /// Reset vector scan description if (has_vector_scan) - context->resetVecScanDescription(); + context->resetVecScanDescriptions(); else if (has_text_search) context->resetTextSearchInfo(); else if (has_hybrid_search) @@ -2351,6 +2369,10 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActio { String result_name = ast->getAliasOrColumnName(); + /// Skip score_type column used for distributed hybrid search fusion + if (hasHybridSearch() && (result_name == SCORE_TYPE_COLUMN.name)) + continue; + if (required_result_columns_set.empty() || 
required_result_columns_set.contains(result_name)) { std::string source_name = ast->getColumnName(); diff --git a/src/Interpreters/ExpressionAnalyzer.h b/src/Interpreters/ExpressionAnalyzer.h index 4b8de10bc2a..5fc1bef6e18 100644 --- a/src/Interpreters/ExpressionAnalyzer.h +++ b/src/Interpreters/ExpressionAnalyzer.h @@ -241,7 +241,9 @@ class ExpressionAnalyzer : protected ExpressionAnalyzerData, private boost::nonc const String & function_col_name, ASTPtr query_column, ASTPtr query_vector, - int topk); + int topk, + String vector_scan_metric_type, + Search::DataType vector_search_type); void analyzeHybridSearch(ActionsDAGPtr & temp_actions); bool makeHybridSearchInfo(ActionsDAGPtr & actions); diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index ce646ffdaa1..06ead611081 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -101,8 +101,11 @@ #include "config_version.h" #include +#include + #if USE_TANTIVY_SEARCH # include +# include #endif namespace DB @@ -781,18 +784,35 @@ InterpreterSelectQuery::InterpreterSelectQuery( analyze(shouldMoveToPrewhere()); #if USE_TANTIVY_SEARCH - if (!options.only_analyze && storage && query_analyzer->getAnalyzedData().text_search_info && context->getSettingsRef().dfs_query_then_fetch) + if (!options.only_analyze && storage && context->getSettingsRef().dfs_query_then_fetch) { - /// Collect global statistics information of all shards used in BM25 calculation when text search is distributed - if (auto distributed_storage = std::dynamic_pointer_cast(storage)) + /// Collect global statistics information of all shards used in BM25 calculation when text/hybrid search is distributed + auto distributed_storage = std::dynamic_pointer_cast(storage); + + if (distributed_storage) { - collectStatisticForBM25Calculation( - context, - distributed_storage->getClusterName(), - distributed_storage->getRemoteDatabaseName(), - distributed_storage->getRemoteTableName(), - query_analyzer->getAnalyzedData().text_search_info->text_column_name, - query_analyzer->getAnalyzedData().text_search_info->query_text); + String text_column_name, query_text; + if (query_analyzer->getAnalyzedData().text_search_info) + { + text_column_name = query_analyzer->getAnalyzedData().text_search_info->text_column_name; + query_text = query_analyzer->getAnalyzedData().text_search_info->query_text; + } + else if (query_analyzer->getAnalyzedData().hybrid_search_info) + { + text_column_name = query_analyzer->getAnalyzedData().hybrid_search_info->text_search_info->text_column_name; + query_text = query_analyzer->getAnalyzedData().hybrid_search_info->text_search_info->query_text; + } + + if (!text_column_name.empty()) + { + collectStatisticForBM25Calculation( + context, + distributed_storage->getClusterName(), + distributed_storage->getRemoteDatabaseName(), + distributed_storage->getRemoteTableName(), + text_column_name, + query_text); + } } } #endif @@ -1820,7 +1840,16 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, std::optional

getShardsInfo().size() > 1)
+            {
+                executeFusionSorted(query_plan);
+            }
+            else
+            {
+                executeMergeSorted(query_plan, "after aggregation stage for ORDER BY");
+            }
+        }
         else if (!expressions.first_stage && !expressions.need_aggregate && !expressions.has_window
@@ -2403,7 +2432,10 @@ void InterpreterSelectQuery::executeFetchColumns(QueryProcessingStage::Enum proc
             /// If there is vector scan in the outer query and main table is subquery, save the vector scan description to subquery.
             if (query_analyzer->hasVectorScan())
-                context->setVecScanDescription(query_analyzer->vectorScanDescs().front());
+            {
+                auto vector_scan_desc_ptr = std::make_shared(query_analyzer->vectorScanDescs());
+                context->setVecScanDescriptions(vector_scan_desc_ptr);
+            }
             interpreter_subquery = std::make_unique(
                 subquery, getSubqueryContext(context),
@@ -2412,7 +2444,7 @@ void InterpreterSelectQuery::executeFetchColumns(QueryProcessingStage::Enum proc
             interpreter_subquery->addStorageLimits(storage_limits);
             if (query_analyzer->hasVectorScan())
-                context->resetVecScanDescription();
+                context->resetVecScanDescriptions();
             if (query_analyzer->hasAggregation())
                 interpreter_subquery->ignoreWithTotals();
@@ -2931,6 +2963,47 @@ void InterpreterSelectQuery::executeOrder(QueryPlan & query_plan, InputOrderInfo
     query_plan.addStep(std::move(sorting_step));
 }
+void InterpreterSelectQuery::executeFusionSorted(QueryPlan & query_plan)
+{
+    const auto & query = getSelectQuery();
+    SortDescription sort_description = getSortDescription(query, context);
+    UInt64 limit = getLimitForSorting(query, context);
+
+    /// Before Distributed HybridSearch Fusion, ORDER BY SCORE_TYPE_COLUMN ASC, HYBRID_SEARCH_SCORE_COLUMN_NAME DESC
+    SortDescription before_fusion_order_descr;
+    before_fusion_order_descr.push_back(SortColumnDescription(SCORE_TYPE_COLUMN.name, 1, 1));
+    before_fusion_order_descr.push_back(SortColumnDescription(HYBRID_SEARCH_SCORE_COLUMN_NAME, -1, 1));
+
+    const Settings & settings = context->getSettingsRef();
+    SortingStep::Settings sort_settings(*context);
+
+    /// Merge and sort the sorted blocks from different shards that include distance or bm25 score.
+    /// Set limit = 0 to get all global distance and BM25 top-k results.
+    auto sorting_step = std::make_unique(
+        query_plan.getCurrentDataStream(),
+        before_fusion_order_descr,
+        0,
+        sort_settings,
+        settings.optimize_sorting_by_input_stream_properties);
+
+    sorting_step->setStepDescription("Sorting before HybridSearch Fusion");
+    query_plan.addStep(std::move(sorting_step));
+
+    /// Fuse the sorted blocks to HybridSearch top-k results.
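    /// (Editorial sketch, not part of the patch; the exact fusion formulas are assumptions.)
    /// Conceptually, for a row ranked r_d by distance and r_b by BM25, a rank-based fusion scores it as
    ///     1 / (hybrid_search_fusion_k + r_d) + 1 / (hybrid_search_fusion_k + r_b),
    /// while a score-based fusion blends normalized scores as
    ///     hybrid_search_fusion_weight * norm_distance + (1 - hybrid_search_fusion_weight) * norm_bm25;
    /// the step constructed below applies whichever variant query_info.hybrid_search_info->fusion_type selects.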
+ auto fusion_sorting = std::make_unique( + query_plan.getCurrentDataStream(), + std::move(sort_description), + limit, + limit * context->getSettingsRef().hybrid_search_top_k_multiple_base, + query_info.hybrid_search_info->fusion_type, + context->getSettingsRef().hybrid_search_fusion_k, + context->getSettingsRef().hybrid_search_fusion_weight, + query_info.hybrid_search_info->vector_scan_info->vector_scan_descs[0].direction); + + fusion_sorting->setStepDescription("HybridSearch Fusion and Sorting"); + query_plan.addStep(std::move(fusion_sorting)); +} + void InterpreterSelectQuery::executeMergeSorted(QueryPlan & query_plan, const std::string & description) { diff --git a/src/Interpreters/InterpreterSelectQuery.h b/src/Interpreters/InterpreterSelectQuery.h index bb1c232f31f..3274c6400b3 100644 --- a/src/Interpreters/InterpreterSelectQuery.h +++ b/src/Interpreters/InterpreterSelectQuery.h @@ -189,6 +189,9 @@ class InterpreterSelectQuery : public IInterpreterUnionOrSelectQuery void executeSubqueriesInSetsAndJoins(QueryPlan & query_plan); bool autoFinalOnQuery(ASTSelectQuery & select_query); + /// Distributed HybridSearch Fusion and Sorting + void executeFusionSorted(QueryPlan & query_plan); + enum class Modificator { ROLLUP = 0, diff --git a/src/Interpreters/TableJoin.cpp b/src/Interpreters/TableJoin.cpp index e10acf440b0..18f07c96999 100644 --- a/src/Interpreters/TableJoin.cpp +++ b/src/Interpreters/TableJoin.cpp @@ -129,7 +129,7 @@ void TableJoin::resetCollected() renames.clear(); left_type_map.clear(); right_type_map.clear(); - right_vector_scan_description.reset(); + right_vector_scan_descs.reset(); } void TableJoin::addUsingKey(const ASTPtr & ast) @@ -743,14 +743,14 @@ void TableJoin::resetToCross() this->table_join.kind = JoinKind::Cross; } -std::optional TableJoin::getVecScanDescription() const +MutableVSDescriptionsPtr TableJoin::getVecScanDescriptions() const { - return right_vector_scan_description; + return right_vector_scan_descs; } -void TableJoin::setVecScanDescription(VSDescription & vec_scan_desc) const +void TableJoin::setVecScanDescriptions(MutableVSDescriptionsPtr vec_scan_descs) const { - right_vector_scan_description = vec_scan_desc; + right_vector_scan_descs = vec_scan_descs; } TextSearchInfoPtr TableJoin::getTextSearchInfoPtr() const diff --git a/src/Interpreters/TableJoin.h b/src/Interpreters/TableJoin.h index 46a6bd34cc0..1f4ef05a405 100644 --- a/src/Interpreters/TableJoin.h +++ b/src/Interpreters/TableJoin.h @@ -152,7 +152,7 @@ class TableJoin NamesAndTypesList columns_added_by_join; /// vector scan functions from joined table - mutable std::optional right_vector_scan_description; + mutable MutableVSDescriptionsPtr right_vector_scan_descs; /// text search info from joined table mutable TextSearchInfoPtr right_text_search_info; /// hybrid search info from joined table @@ -365,8 +365,8 @@ class TableJoin std::shared_ptr getStorageKeyValue() { return right_kv_storage; } /// Used for vector scan functions - std::optional getVecScanDescription() const; - void setVecScanDescription(VSDescription & vec_scan_desc) const; + MutableVSDescriptionsPtr getVecScanDescriptions() const; + void setVecScanDescriptions(MutableVSDescriptionsPtr vec_scan_descs) const; /// Used for text search function TextSearchInfoPtr getTextSearchInfoPtr() const; diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index c540e084a8b..2ee4aa45e6a 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -59,6 +59,7 @@ #include 
"Common/Allocator.h" #include #include +#include #include #include @@ -74,6 +75,7 @@ #include #include #include +#include #include @@ -902,60 +904,99 @@ using RewriteShardNumVisitor = InDepthNodeVisitor; /// Get hybrid search related functions(distance, batch_distance, TextSearch and HybridSearch), remove duplicated functions void getHybridSearchFunctions( ASTPtr & query, - const ASTSelectQuery & select_query, + ASTSelectQuery * select_query, std::vector & hybrid_search_functions, HybridSearchFuncType & search_func_type) { GetHybridSearchVisitor::Data data; GetHybridSearchVisitor(data).visit(query); + /// Mark if query contains multiple distances + bool has_multiple_distances = false; + size_t hybrid_search_func_count = data.vector_scan_funcs.size() + data.text_search_func.size() + data.hybrid_search_func.size(); if (hybrid_search_func_count == 0) return ; else if (hybrid_search_func_count > 1) - throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support more than one function among vector scan, text search, or hybrid search in one query now"); + { + /// Support multiple vector scan funcs + if (data.vector_scan_funcs.size() != hybrid_search_func_count) + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Only support multiple distance functions in one query now"); + + has_multiple_distances = true; + } String search_func_name; - if (data.vector_scan_funcs.size() == 1) + if (data.vector_scan_funcs.size() >= 1) { hybrid_search_functions = data.vector_scan_funcs; search_func_type = HybridSearchFuncType::VECTOR_SCAN; - search_func_name = "distance"; + search_func_name = DISTANCE_FUNCTION; + + if (has_multiple_distances) + { + /// multiple distances: need to set mark flag in ASTFunction + for (const auto & vector_scan_func : data.all_multiple_vector_scan_funcs) + vector_scan_func->is_from_multiple_distances = true; + } } else if (data.text_search_func.size() == 1) { hybrid_search_functions = data.text_search_func; search_func_type = HybridSearchFuncType::TEXT_SEARCH; - search_func_name = "TextSearch"; + search_func_name = TEXT_SEARCH_FUNCTION; } else if (data.hybrid_search_func.size() == 1) { hybrid_search_functions = data.hybrid_search_func; search_func_type = HybridSearchFuncType::HYBRID_SEARCH; - search_func_name = "HybridSearch"; + search_func_name = HYBRID_SEARCH_FUNCTION; } - if (!select_query.orderBy()) + /// Remove the restriction that distance() function must exist in order by clause. 
+ if (search_func_type == HybridSearchFuncType::VECTOR_SCAN) { - /// TODO: Will be removed when distance functions are implemented - throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support {} function without ORDER BY clause", search_func_name); - } + /// Add default order by clause if not specified + if (!select_query->orderBy()) + { + ASTPtr default_order_by_ast = std::make_shared(); - bool is_batch = hybrid_search_functions.size() == 1 && isBatchDistance(hybrid_search_functions[0]->getColumnName()); - if (!is_batch && !select_query.limitLength()) - throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support {} function without LIMIT N clause", search_func_name); - else if (is_batch && !select_query.limitByLength()) - throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support batch {} function without LIMIT N BY clause", search_func_name); + auto virtual_part = std::make_shared(); + virtual_part->direction = 1; + virtual_part->children.emplace_back(std::make_shared("_part")); + + auto virtual_row_id = std::make_shared(); + virtual_row_id->direction = 1; + virtual_row_id->children.emplace_back(std::make_shared("_part_offset")); - if (select_query.orderBy()) + default_order_by_ast->children.push_back(std::move(virtual_part)); + default_order_by_ast->children.push_back(std::move(virtual_row_id)); + + select_query->setExpression(ASTSelectQuery::Expression::ORDER_BY, std::move(default_order_by_ast)); + } + } + else /// TextSearch/HybridSearch { - GetHybridSearchVisitor::Data order_by_data; - GetHybridSearchVisitor(order_by_data).visit(select_query.orderBy()); + if (!select_query->orderBy()) + { + throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support {} function without ORDER BY clause", search_func_name); + } + else + { + GetHybridSearchVisitor::Data order_by_data; + GetHybridSearchVisitor(order_by_data).visit(select_query->orderBy()); - auto search_func_count = order_by_data.vector_scan_funcs.size() + order_by_data.text_search_func.size() + order_by_data.hybrid_search_func.size(); - if (search_func_count != 1) - throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support without {} function inside ORDER BY clause", search_func_name); + auto search_func_count = order_by_data.text_search_func.size() + order_by_data.hybrid_search_func.size(); + if (search_func_count != 1) + throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support without {} function inside ORDER BY clause", search_func_name); + } } + + bool is_batch = hybrid_search_functions.size() == 1 && isBatchDistance(hybrid_search_functions[0]->getColumnName()); + if (!is_batch && !select_query->limitLength()) + throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support {} function without LIMIT N clause", search_func_name); + else if (is_batch && !select_query->limitByLength()) + throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support batch_{} function without LIMIT N BY clause", search_func_name); } void addSearchFunctionColumnName(const String & func_col_name, NamesAndTypesList & source_columns, ASTSelectQuery * select_query = nullptr) @@ -1002,7 +1043,7 @@ void addSearchFunctionColumnName(const String & func_col_name, NamesAndTypesList } void checkOrderBySortDirection( - String func_name, ASTSelectQuery * select_query, int & sort_direction, int expected_direction, + String func_name, ASTSelectQuery * select_query, int expected_direction, String metric_type = "", bool is_batch = false) { if (!select_query) @@ -1010,7 +1051,10 @@ void checkOrderBySortDirection( auto order_by = select_query->orderBy(); if (!order_by) - throw 
Exception(ErrorCodes::SYNTAX_ERROR, "Not support {} function without ORDER BY clause", func_name); + return; /// order by is already checked and handled by getHybridSearchFunctions() + + int sort_direction = 0; + bool find_search_function = false; /// Find the direction for hybrid search func for (const auto & child : order_by->children) @@ -1020,10 +1064,28 @@ void checkOrderBySortDirection( continue; ASTPtr order_expression = order_by_element->children.at(0); - if (!is_batch && isHybridSearchFunc(order_expression->getColumnName())) + if (!is_batch) { - sort_direction = order_by_element->direction; - break; + /// Check cases when search function column is an argument of other functions + if (isHybridSearchFunc(order_expression->getColumnName())) + { + sort_direction = order_by_element->direction; + find_search_function = true; + break; + } + else if (auto * function = order_expression->as()) + { + const ASTs & func_arguments = function->arguments->as().children; + for (const auto & func_arg : func_arguments) + { + if (isHybridSearchFunc(func_arg->getColumnName())) + { + sort_direction = order_by_element->direction; + find_search_function = true; + break; + } + } + } } else if (is_batch) { @@ -1036,6 +1098,7 @@ void checkOrderBySortDirection( if (func_arguments[1]->getColumnName() == "2") { sort_direction = order_by_element->direction; + find_search_function = true; break; } } @@ -1043,7 +1106,8 @@ void checkOrderBySortDirection( } } - if (sort_direction != expected_direction) + /// Only check the direction of search function when found in order by + if (find_search_function && sort_direction != expected_direction) { String vector_scan_error_log = metric_type.empty() ? "" : " when the metric type is " + metric_type; throw Exception(ErrorCodes::SYNTAX_ERROR, @@ -1268,6 +1332,14 @@ bool TreeRewriterResult::collectUsedColumns(const ASTPtr & query, bool is_select const String func_column_name = node->getColumnName(); addSearchFunctionColumnName(func_column_name, source_columns); unknown_required_source_columns.erase(func_column_name); + + /// Add score type column for distributed storage with multiple shards + auto * distributed = dynamic_cast(const_cast(storage.get())); + if (isHybridSearch(func_column_name) && distributed && distributed->getCluster()->getShardsInfo().size() > 1) + { + source_columns.push_back(SCORE_TYPE_COLUMN); + unknown_required_source_columns.erase(SCORE_TYPE_COLUMN.name); + } } } @@ -1513,10 +1585,20 @@ std::optional TreeRewriterResult::collectSearchColumnType( if (!search_column_type) throw Exception(ErrorCodes::BAD_ARGUMENTS, "search column name: {}, type is not exist", search_col_name); - /// Check if table with vector index contains the same column as vector scan function column. + /// Check if table with vector index contains the same column as search function column. 
if (metadata_snapshot && metadata_snapshot->getColumns().has(func_col_name)) throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support search function on table with column name '{}'", func_col_name); + if (hybrid_search_from_right_table) + { + /// distance func column name should add to right joined table's source columns + addSearchFunctionColumnName(func_col_name, analyzed_join->columns_from_joined_table); + + /// Add distance func column to original_names too + auto & original_names = analyzed_join->original_names; + original_names[func_col_name] = func_col_name; + } + return search_column_type; } @@ -1526,122 +1608,152 @@ void TreeRewriterResult::collectForHybridSearchRelatedFunctions( ContextPtr context) { /// distance function exists in main query's select caluse - if (hybrid_search_funcs.size() == 1) + size_t search_funcs_size = hybrid_search_funcs.size(); + if (search_funcs_size > 0) { String function_name; size_t expected_args_size = 0; - bool is_batch = false; bool has_vector = false; bool has_text = false; - const ASTFunction * node = hybrid_search_funcs[0]; if (search_func_type == HybridSearchFuncType::VECTOR_SCAN) { has_vector = true; expected_args_size = 2; - is_batch = isBatchDistance(node->getColumnName()); - - function_name = is_batch ? "batch_distance" : "distance"; } else if (search_func_type == HybridSearchFuncType::TEXT_SEARCH) { has_text = true; - function_name = "TextSearch"; + function_name = TEXT_SEARCH_FUNCTION; expected_args_size = 2; } else if (search_func_type == HybridSearchFuncType::HYBRID_SEARCH) { has_vector = true; has_text = true; - function_name = "HybridSearch"; + function_name = HYBRID_SEARCH_FUNCTION; expected_args_size = 4; } - const ASTs & arguments = node->arguments ? node->arguments->children : ASTs(); - - /// There is no real search function, hence the checks like input paramters are put here. - if (arguments.size() != expected_args_size) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "wrong argument number in {} function: {}, expected {}", function_name, arguments.size(), expected_args_size); + /// Support multiple distance functions + bool has_multiple_distances = search_funcs_size > 1; + StorageMetadataPtr metadata_snapshot = nullptr; + bool table_is_remote = false; /// Mark if the storage with search column is distributed. - /// Get topK from limit N - limit_length = getTopKFromLimit(select_query, context, is_batch); + vector_scan_metric_types.resize(search_funcs_size); + vector_search_types.resize(search_funcs_size); - if (limit_length == 0) + for (size_t i = 0; i < search_funcs_size; ++i) { - if (is_batch) - throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support batch_distance function without LIMIT BY clause"); - else - throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support {} function without LIMIT clause", function_name); - } + const ASTFunction * node = hybrid_search_funcs[i]; + bool is_batch = false; - String function_col_name = node->getColumnName(); /// column name of search function - ASTPtr vector_argument; - ASTPtr text_argument; + /// Initialize for vector scan function + if (search_func_type == HybridSearchFuncType::VECTOR_SCAN) + { + is_batch = isBatchDistance(node->getColumnName()); + function_name = is_batch ? 
"batch_distance" : "distance"; + } - /// Use the first argument in search function - /// vector column in vector scan and hybrid search, text column in text search - if (has_vector) - { - vector_argument = arguments[0]; + if (has_multiple_distances && (is_batch || search_func_type != HybridSearchFuncType::VECTOR_SCAN)) + throw Exception(ErrorCodes::SYNTAX_ERROR, "Only support multiple distance functions in one query"); - /// hybrid search - if (has_text) - text_argument = arguments[1]; - } - else - text_argument = arguments[0]; + const ASTs & arguments = node->arguments ? node->arguments->children : ASTs(); - String vector_col_name = vector_argument ? vector_argument->getColumnName() : ""; - String text_col_name = text_argument ? text_argument->getColumnName() : ""; + /// There is no real search function, hence the checks like input parameters are put here. + if (arguments.size() != expected_args_size) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "wrong argument number in {} function: {}, expected {}", function_name, arguments.size(), expected_args_size); - StorageMetadataPtr metadata_snapshot = nullptr; - bool table_is_remote = false; /// Mark if the storage with search column is distributed. + /// Only need to get k in limit for once + if (i == 0) + { + /// Get topK from limit N + limit_length = getTopKFromLimit(select_query, context, is_batch); - if (has_vector) - { - auto search_vector_column_type = collectSearchColumnType(vector_col_name, function_col_name, tables_with_columns, context, - vector_argument, metadata_snapshot, table_is_remote); + if (limit_length == 0) + { + if (is_batch) + throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support batch_distance function without LIMIT BY clause"); + else + throw Exception(ErrorCodes::SYNTAX_ERROR, "Not support {} function without LIMIT clause", function_name); + } + } - vector_search_type = getSearchIndexDataType(search_vector_column_type->type); + String function_col_name = node->getColumnName(); /// column name of search function + ASTPtr vector_argument; + ASTPtr text_argument; - vector_scan_metric_type = getMetricType(metadata_snapshot, vector_search_type, vector_col_name, context); - } + /// Use the first argument in search function + /// vector column in vector scan and hybrid search, text column in text search + if (has_vector) + { + vector_argument = arguments[0]; - if (has_text) - { - /// Check mappKeys for text column - bool is_mapkeys = false; - if (const ASTFunction * function = text_argument->as()) + /// hybrid search + if (has_text) + text_argument = arguments[1]; + } + else + text_argument = arguments[0]; + + String vector_col_name = vector_argument ? vector_argument->getColumnName() : ""; + String text_col_name = text_argument ? 
text_argument->getColumnName() : ""; + String vector_scan_metric_type; + + if (has_vector) { - if ((function->name == "mapKeys") && function->arguments) + auto search_vector_column_type = collectSearchColumnType(vector_col_name, function_col_name, tables_with_columns, context, + vector_argument, metadata_snapshot, table_is_remote); + + auto vector_search_type = getSearchIndexDataType(search_vector_column_type->type); + vector_search_types[i] = vector_search_type; + + vector_scan_metric_type = getMetricType(metadata_snapshot, vector_search_type, vector_col_name, context); + vector_scan_metric_types[i] = vector_scan_metric_type; + } + + if (has_text) + { + /// Check mappKeys for text column + bool is_mapkeys = false; + if (const ASTFunction * function = text_argument->as()) { - const auto & function_arguments_list = function->arguments->as()->children; - if (function_arguments_list.size() == 1) + if ((function->name == "mapKeys") && function->arguments) { - text_col_name = function_arguments_list[0]->getColumnName(); - is_mapkeys = true; + const auto & function_arguments_list = function->arguments->as()->children; + if (function_arguments_list.size() == 1) + { + text_col_name = function_arguments_list[0]->getColumnName(); + is_mapkeys = true; + } } } + auto search_text_column_type = collectSearchColumnType(text_col_name, function_col_name, tables_with_columns, context, + text_argument, metadata_snapshot, table_is_remote); + + checkTextSearchColumnDataType(search_text_column_type->type, is_mapkeys); } - auto search_text_column_type = collectSearchColumnType(text_col_name, function_col_name, tables_with_columns, context, - text_argument, metadata_snapshot, table_is_remote); - checkTextSearchColumnDataType(search_text_column_type->type, is_mapkeys); + /// The direction of TextSearch/HybridSearch func in order by should be DESC + if (has_text) + checkOrderBySortDirection(function_name, select_query, -1); + else + { + /// When metric_type = IP in definition of vector index, order by must be DESC. + /// Skip the check when table is distributed or query has multiple distance functions + if (!table_is_remote && !has_multiple_distances) + { + String metric_type = vector_scan_metric_type; + Poco::toUpperInPlace(metric_type); + checkOrderBySortDirection(function_name, select_query, metric_type == "IP" ? -1 : 1, metric_type, is_batch); + } + } } - /// The direction of TextSearch/HybridSearch func in order by should be DESC - if (has_text) - checkOrderBySortDirection(function_col_name, select_query, direction, -1); - else + /// topk in multiple distance functions case should be 3 * k + if (has_multiple_distances) { - /// When metric_type = IP in definition of vector index, order by must be DESC. - /// Skip the check when table is distributed. - if (!table_is_remote) - { - String metric_type = vector_scan_metric_type; - Poco::toUpperInPlace(metric_type); - checkOrderBySortDirection("distance", select_query, direction, metric_type == "IP" ? 
-1 : 1, metric_type, is_batch); - } + limit_length = limit_length * context->getSettingsRef().distances_top_k_multiply_factor; } } else @@ -1650,8 +1762,11 @@ void TreeRewriterResult::collectForHybridSearchRelatedFunctions( /// Add search func name and type to source columns /// Add search func name to select clauses if not exists String search_func_col_name; - if (auto vector_scan_desc = context->getVecScanDescription()) - search_func_col_name = vector_scan_desc->column_name; + if (auto vector_scan_descs = context->getVecScanDescriptions()) + { + for (const auto & vec_scan_desc : *vector_scan_descs) + addSearchFunctionColumnName(vec_scan_desc.column_name, source_columns, select_query); + } else if (auto text_search_info = context->getTextSearchInfo()) search_func_col_name = text_search_info->function_column_name; else if (auto hybrid_search_info = context->getHybridSearchInfo()) @@ -1774,7 +1889,36 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect( column.name = StorageView::replaceQueryParameterWithValue(column.name, parameter_values, parameter_types); } - getHybridSearchFunctions(query, *select_query, result.hybrid_search_funcs, result.search_func_type); + getHybridSearchFunctions(query, select_query, result.hybrid_search_funcs, result.search_func_type); + + /// Add score_type column and fusion id columns for multiple-shard distributed hybrid search + /// fusion id columns: _shard_num, _part_index, _part_offset + auto * distributed = dynamic_cast(const_cast(result.storage.get())); + if (result.search_func_type == HybridSearchFuncType::HYBRID_SEARCH && distributed + && distributed->getCluster()->getShardsInfo().size() > 1) + { + auto score_type_identifier = std::make_shared(SCORE_TYPE_COLUMN.name); + select_query->select()->children.push_back(score_type_identifier); + + if (!select_query->orderBy()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Hybrid search requires ORDER BY clause."); + + auto shard_num_element = std::make_shared(); + shard_num_element->direction = 1; + shard_num_element->children.emplace_back(makeASTFunction("shardNum")); + + auto part_index_element = std::make_shared(); + part_index_element->direction = 1; + part_index_element->children.emplace_back(std::make_shared("_part_index")); + + auto part_offset_element = std::make_shared(); + part_offset_element->direction = 1; + part_offset_element->children.emplace_back(std::make_shared("_part_offset")); + + select_query->orderBy()->children.push_back(std::move(shard_num_element)); + select_query->orderBy()->children.push_back(std::move(part_index_element)); + select_query->orderBy()->children.push_back(std::move(part_offset_element)); + } /// Special handling for vector scan, text search and hybrid search function result.collectForHybridSearchRelatedFunctions(select_query, tables_with_columns, getContext()); diff --git a/src/Interpreters/TreeRewriter.h b/src/Interpreters/TreeRewriter.h index 965c47b8704..9573b6ec777 100644 --- a/src/Interpreters/TreeRewriter.h +++ b/src/Interpreters/TreeRewriter.h @@ -55,10 +55,9 @@ struct TreeRewriterResult HybridSearchFuncType search_func_type = HybridSearchFuncType::UNKNOWN_FUNC; /// Save vector scan metric_type - String vector_scan_metric_type; - Search::DataType vector_search_type; + std::vector vector_scan_metric_types; + std::vector vector_search_types; UInt64 limit_length = 0; - int direction = 1; /// True if hybrid search function is from right table bool hybrid_search_from_right_table = false; diff --git a/src/Parsers/ASTFunction.cpp b/src/Parsers/ASTFunction.cpp index 
1becf500a41..25bdfdea64b 100644 --- a/src/Parsers/ASTFunction.cpp +++ b/src/Parsers/ASTFunction.cpp @@ -491,6 +491,22 @@ void ASTFunction::appendColumnNameImpl(WriteBuffer & ostr) const if (isHybridSearchFunc(name)) { writeString("_func", ostr); + + /// Support multiple distance functions + if (is_from_multiple_distances && isDistance(name)) + { + ostr.write('_'); + if (!alias.empty()) + writeString(alias, ostr); + else + { + Hash hash = getTreeHash(); + writeText(hash.first, ostr); + ostr.write('_'); + writeText(hash.second, ostr); + } + } + return; } diff --git a/src/Parsers/ASTFunction.h b/src/Parsers/ASTFunction.h index 4a036c5e94a..48757a3f790 100644 --- a/src/Parsers/ASTFunction.h +++ b/src/Parsers/ASTFunction.h @@ -74,6 +74,11 @@ class ASTFunction : public ASTWithAlias /// This is used for parameterized view, to identify if name is 'db.view' bool is_compound_name = false; + /// True if query has multiple distance functions + /// Single distance function's column name is 'distance_func' + /// while in multiple cases it's 'distance_func_' + mutable bool is_from_multiple_distances = false; + bool hasSecretParts() const override; protected: diff --git a/src/Storages/ColumnsDescription.cpp b/src/Storages/ColumnsDescription.cpp index 49fd963f267..ae6ed4b924b 100644 --- a/src/Storages/ColumnsDescription.cpp +++ b/src/Storages/ColumnsDescription.cpp @@ -33,6 +33,7 @@ #include #include +#include #include @@ -559,6 +560,12 @@ NamesAndTypesList ColumnsDescription::getByNames(const GetColumnsOptions & optio res.emplace_back(name, std::make_shared()); continue; } + else if (name == SCORE_TYPE_COLUMN.name) + { + res.emplace_back(SCORE_TYPE_COLUMN); + continue; + } + if (isBatchDistance(name)) { auto id_type = std::make_shared(); diff --git a/src/Storages/MergeTree/MergeTreeIndexTantivy.cpp b/src/Storages/MergeTree/MergeTreeIndexTantivy.cpp index 3f83f6cc7c3..a09f3f0a9de 100644 --- a/src/Storages/MergeTree/MergeTreeIndexTantivy.cpp +++ b/src/Storages/MergeTree/MergeTreeIndexTantivy.cpp @@ -789,7 +789,7 @@ void ftsIndexValidator(const IndexDescription & index, bool /*attach*/) String index_json_parameter = index.arguments.empty() ? 
"{}" : index.arguments[0].get(); - FFIBoolResult json_status = ffi_verify_index_parameter(index_json_parameter); + TANTIVY::FFIBoolResult json_status = TANTIVY::ffi_verify_index_parameter(index_json_parameter); if (json_status.error.is_error) { throw DB::Exception(ErrorCodes::BAD_ARGUMENTS, "{}", std::string(json_status.error.message)); diff --git a/src/Storages/MergeTree/TantivyIndexStore.cpp b/src/Storages/MergeTree/TantivyIndexStore.cpp index ab0c6415608..17772fb1341 100644 --- a/src/Storages/MergeTree/TantivyIndexStore.cpp +++ b/src/Storages/MergeTree/TantivyIndexStore.cpp @@ -651,7 +651,7 @@ bool TantivyIndexStore::getTantivyIndexReader() { LOG_INFO(log, "[getTantivyIndexReader] initializing FTS index reader, FTS index cache directory is {}", index_files_cache_path); this->index_files_manager->deserialize(); - FFIBoolResult load_status = ffi_load_index_reader(index_files_cache_path); + TANTIVY::FFIBoolResult load_status = TANTIVY::ffi_load_index_reader(index_files_cache_path); size_t retry_times = 0; const int base_wait_slots = 50; while (load_status.error.is_error && retry_times < DESERIALIZE_MAX_RETRY_TIMES) @@ -665,7 +665,7 @@ bool TantivyIndexStore::getTantivyIndexReader() retry_times); this->index_files_manager->removeTantivyIndexCacheDirectory(); this->index_files_manager->deserialize(); - load_status = ffi_load_index_reader(index_files_cache_path); + load_status = TANTIVY::ffi_load_index_reader(index_files_cache_path); retry_times += 1; } if (load_status.error.is_error) @@ -710,8 +710,8 @@ bool TantivyIndexStore::getTantivyIndexWriter() return writer_ready; LOG_INFO(log, "[getTantivyIndexWriter] initializing FTS index writer, FTS index cache directory is {}", index_files_cache_path); - FFIBoolResult create_status - = ffi_create_index_with_parameter(index_files_cache_path, index_settings.indexed_columns, index_settings.index_json_parameter); + TANTIVY::FFIBoolResult create_status = TANTIVY::ffi_create_index_with_parameter( + index_files_cache_path, index_settings.indexed_columns, index_settings.index_json_parameter); if (create_status.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(create_status.error.message)); @@ -739,7 +739,7 @@ bool TantivyIndexStore::indexMultiColumnDoc(uint64_t row_id, std::vector getTantivyIndexWriter(); String index_files_cache_path = this->index_files_manager->getTantivyIndexCacheDirectory(); - FFIBoolResult index_status = ffi_index_multi_column_docs(index_files_cache_path, row_id, column_names, docs); + TANTIVY::FFIBoolResult index_status = TANTIVY::ffi_index_multi_column_docs(index_files_cache_path, row_id, column_names, docs); if (index_status.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(index_status.error.message)); @@ -766,7 +766,7 @@ bool TantivyIndexStore::freeTantivyIndexReader() if (this->index_reader_status) { - FFIBoolResult free_status = ffi_free_index_reader(index_files_cache_path); + TANTIVY::FFIBoolResult free_status = TANTIVY::ffi_free_index_reader(index_files_cache_path); if (free_status.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(free_status.error.message)); @@ -789,7 +789,7 @@ bool TantivyIndexStore::freeTantivyIndexWriter() if (getIndexWriterStatus()) { - FFIBoolResult free_status = ffi_free_index_writer(index_files_cache_path); + TANTIVY::FFIBoolResult free_status = TANTIVY::ffi_free_index_writer(index_files_cache_path); if (free_status.error.is_error) { throw 
DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(free_status.error.message)); @@ -821,7 +821,7 @@ void TantivyIndexStore::commitTantivyIndex() } String index_files_cache_path = this->index_files_manager->getTantivyIndexCacheDirectory(); - FFIBoolResult commit_result = ffi_index_writer_commit(index_files_cache_path); + TANTIVY::FFIBoolResult commit_result = TANTIVY::ffi_index_writer_commit(index_files_cache_path); if (commit_result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(commit_result.error.message)); @@ -849,7 +849,8 @@ rust::cxxbridge1::Vec TantivyIndexStore::singleTermQueryBitmap(Str if (!index_reader_status) getTantivyIndexReader(); - FFIVecU8Result result = ffi_query_term_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, term); + TANTIVY::FFIVecU8Result result + = TANTIVY::ffi_query_term_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, term); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); @@ -861,7 +862,8 @@ rust::cxxbridge1::Vec TantivyIndexStore::sentenceQueryBitmap(Strin if (!index_reader_status) getTantivyIndexReader(); - FFIVecU8Result result = ffi_query_sentence_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, sentence); + TANTIVY::FFIVecU8Result result + = TANTIVY::ffi_query_sentence_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, sentence); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); @@ -873,7 +875,8 @@ rust::cxxbridge1::Vec TantivyIndexStore::regexTermQueryBitmap(Stri if (!index_reader_status) getTantivyIndexReader(); - FFIVecU8Result result = ffi_regex_term_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, pattern); + TANTIVY::FFIVecU8Result result + = TANTIVY::ffi_regex_term_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, pattern); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); @@ -885,7 +888,8 @@ rust::cxxbridge1::Vec TantivyIndexStore::termsQueryBitmap(String c if (!index_reader_status) getTantivyIndexReader(); - FFIVecU8Result result = ffi_query_terms_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, terms); + TANTIVY::FFIVecU8Result result + = TANTIVY::ffi_query_terms_bitmap(this->index_files_manager->getTantivyIndexCacheDirectory(), column_name, terms); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); @@ -893,16 +897,18 @@ rust::cxxbridge1::Vec TantivyIndexStore::termsQueryBitmap(String c return result.result; } -rust::cxxbridge1::Vec TantivyIndexStore::bm25Search(String sentence, bool enable_nlq, bool operator_or, Statistics & statistics, size_t topk) +rust::cxxbridge1::Vec TantivyIndexStore::bm25Search( + String sentence, bool enable_nlq, bool operator_or, TANTIVY::Statistics & statistics, size_t topk, std::vector column_names) { DB::OpenTelemetry::SpanHolder span("TantivyIndexStore::bm25_search"); if (!index_reader_status) getTantivyIndexReader(); std::vector u8_alived_bitmap; - FFIVecRowIdWithScoreResult result = ffi_bm25_search( + TANTIVY::FFIVecRowIdWithScoreResult result = TANTIVY::ffi_bm25_search( 
this->index_files_manager->getTantivyIndexCacheDirectory(), sentence, + column_names, static_cast(topk), u8_alived_bitmap, false, @@ -917,16 +923,23 @@ rust::cxxbridge1::Vec TantivyIndexStore::bm25Search(String sente return result.result; } -rust::cxxbridge1::Vec TantivyIndexStore::bm25SearchWithFilter( - String sentence, bool enable_nlq, bool operator_or, Statistics & statistics, size_t topk, const std::vector & u8_alived_bitmap) +rust::cxxbridge1::Vec TantivyIndexStore::bm25SearchWithFilter( + String sentence, + bool enable_nlq, + bool operator_or, + TANTIVY::Statistics & statistics, + size_t topk, + const std::vector & u8_alived_bitmap, + std::vector column_names) { DB::OpenTelemetry::SpanHolder span("TantivyIndexStore::bm25_search_with_filter"); if (!index_reader_status) getTantivyIndexReader(); - FFIVecRowIdWithScoreResult result = ffi_bm25_search( + TANTIVY::FFIVecRowIdWithScoreResult result = TANTIVY::ffi_bm25_search( this->index_files_manager->getTantivyIndexCacheDirectory(), sentence, + column_names, static_cast(topk), u8_alived_bitmap, true, @@ -940,11 +953,13 @@ rust::cxxbridge1::Vec TantivyIndexStore::bm25SearchWithFilter( return result.result; } -rust::cxxbridge1::Vec TantivyIndexStore::getDocFreq(String sentence) + +rust::cxxbridge1::Vec TantivyIndexStore::getDocFreq(String sentence) { if (!index_reader_status) getTantivyIndexReader(); - FFIVecDocWithFreqResult result = ffi_get_doc_freq(this->index_files_manager->getTantivyIndexCacheDirectory(), sentence); + TANTIVY::FFIVecDocWithFreqResult result + = TANTIVY::ffi_get_doc_freq(this->index_files_manager->getTantivyIndexCacheDirectory(), sentence); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); @@ -956,7 +971,7 @@ UInt64 TantivyIndexStore::getTotalNumDocs() { if (!index_reader_status) getTantivyIndexReader(); - FFIU64Result result = ffi_get_total_num_docs(this->index_files_manager->getTantivyIndexCacheDirectory()); + TANTIVY::FFIU64Result result = TANTIVY::ffi_get_total_num_docs(this->index_files_manager->getTantivyIndexCacheDirectory()); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); @@ -964,11 +979,11 @@ UInt64 TantivyIndexStore::getTotalNumDocs() return result.result; } -rust::cxxbridge1::Vec TantivyIndexStore::getTotalNumTokens() +rust::cxxbridge1::Vec TantivyIndexStore::getTotalNumTokens() { if (!index_reader_status) getTantivyIndexReader(); - FFIFieldTokenNumsResult result = ffi_get_total_num_tokens(this->index_files_manager->getTantivyIndexCacheDirectory()); + TANTIVY::FFIFieldTokenNumsResult result = TANTIVY::ffi_get_total_num_tokens(this->index_files_manager->getTantivyIndexCacheDirectory()); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); @@ -980,7 +995,7 @@ UInt64 TantivyIndexStore::getIndexedDocsNum() { if (!index_reader_status) getTantivyIndexReader(); - FFIU64Result result = ffi_get_indexed_doc_counts(this->index_files_manager->getTantivyIndexCacheDirectory()); + TANTIVY::FFIU64Result result = TANTIVY::ffi_get_indexed_doc_counts(this->index_files_manager->getTantivyIndexCacheDirectory()); if (result.error.is_error) { throw DB::Exception(ErrorCodes::TANTIVY_SEARCH_INTERNAL_ERROR, "{}", std::string(result.error.message)); diff --git a/src/Storages/MergeTree/TantivyIndexStore.h b/src/Storages/MergeTree/TantivyIndexStore.h index 9fa856af64c..8e84b94a9db 100644 
--- a/src/Storages/MergeTree/TantivyIndexStore.h +++ b/src/Storages/MergeTree/TantivyIndexStore.h @@ -165,16 +165,29 @@ class TantivyIndexStore rust::cxxbridge1::Vec termsQueryBitmap(String column_name, std::vector terms); /// For BM25Search and HybridSearch. If enable_nlq is true, use Natural Language Search. If enable_nlq is false, use Standard Search. - rust::cxxbridge1::Vec bm25Search(String sentence, bool enable_nlq, bool operator_or, Statistics & statistics, size_t topk); - rust::cxxbridge1::Vec - bm25SearchWithFilter(String sentence, bool enable_nlq, bool operator_or, Statistics & statistics, size_t topk, const std::vector & u8_alived_bitmap); + rust::cxxbridge1::Vec bm25Search( + String sentence, + bool enable_nlq, + bool operator_or, + TANTIVY::Statistics & statistics, + size_t topk, + std::vector column_names = {}); + + rust::cxxbridge1::Vec bm25SearchWithFilter( + String sentence, + bool enable_nlq, + bool operator_or, + TANTIVY::Statistics & statistics, + size_t topk, + const std::vector & u8_alived_bitmap, + std::vector column_names = {}); /// Get current part sentence doc_freq, sentence will be tokenized by tokenizer with each indexed column. - rust::cxxbridge1::Vec getDocFreq(String sentence); + rust::cxxbridge1::Vec getDocFreq(String sentence); /// Get current part total_num_docs, each column will have same total_num_docs. UInt64 getTotalNumDocs(); /// Get current part total_num_tokens, each column will have it's own total_num_tokens. - rust::cxxbridge1::Vec getTotalNumTokens(); + rust::cxxbridge1::Vec getTotalNumTokens(); /// Get the number of documents stored in the index file. UInt64 getIndexedDocsNum(); diff --git a/src/Storages/StorageDistributed.cpp b/src/Storages/StorageDistributed.cpp index 04aaaf1c696..0e711c25f87 100644 --- a/src/Storages/StorageDistributed.cpp +++ b/src/Storages/StorageDistributed.cpp @@ -93,6 +93,7 @@ #include #include #include +#include #include #include @@ -107,6 +108,8 @@ #include #include +#include + namespace fs = std::filesystem; @@ -1051,6 +1054,12 @@ void StorageDistributed::read( if (select_query->final() && local_context->getSettingsRef().allow_experimental_parallel_reading_from_replicas) throw Exception(ErrorCodes::ILLEGAL_FINAL, "Final modifier is not allowed together with parallel reading from replicas feature"); + // Distributed HybridSearch need to split query into vector search and text search + if (query_info.has_hybrid_search && query_info.hybrid_search_info && getShardCount() > 1) + { + return readHybridSearch(query_plan, storage_snapshot, query_info, local_context, processed_stage); + } + Block header; ASTPtr query_ast; @@ -1138,6 +1147,166 @@ void StorageDistributed::read( throw Exception(ErrorCodes::LOGICAL_ERROR, "Pipeline is not initialized"); } +void StorageDistributed::readHybridSearch( + QueryPlan & query_plan, + const StorageSnapshotPtr & storage_snapshot, + SelectQueryInfo & query_info, + ContextPtr local_context, + QueryProcessingStage::Enum processed_stage) +{ + auto settings = local_context->getSettingsRef(); + + if (settings.allow_experimental_analyzer) + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "HybridSearch is not supported with experimental analyzer"); + + Block header = InterpreterSelectQuery(query_info.query, local_context, SelectQueryOptions(processed_stage).analyze()).getSampleBlock(); + + /// Return directly (with correct header) if no shard to query. 
+ if (query_info.getCluster()->getShardsInfo().empty()) + { + Pipe pipe(std::make_shared(header)); + auto read_from_pipe = std::make_unique(std::move(pipe)); + read_from_pipe->setStepDescription("Read from NullSource (Distributed)"); + query_plan.addStep(std::move(read_from_pipe)); + + return; + } + + ASTPtr query_ast_vector_search, query_ast_text_search; + splitHybridSearchAST( + query_info.query, + query_ast_vector_search, + query_ast_text_search, + query_info.hybrid_search_info->vector_scan_info->vector_scan_descs[0].direction, + query_info.hybrid_search_info->vector_scan_info->vector_scan_descs[0].topk, + query_info.hybrid_search_info->text_search_info->topk, + query_info.hybrid_search_info->text_search_info->enable_nlq, + query_info.hybrid_search_info->text_search_info->text_operator); + + StorageID main_table = StorageID::createEmpty(); + if (!remote_table_function_ptr) + main_table = StorageID{remote_database, remote_table}; + + const auto & snapshot_data = assert_cast(*storage_snapshot->data); + + ClusterProxy::AdditionalShardFilterGenerator additional_shard_filter_generator; + if (query_info.use_custom_key) + { + if (auto custom_key_ast = parseCustomKeyForTable(settings.parallel_replicas_custom_key, *local_context)) + { + if (query_info.getCluster()->getShardCount() == 1) + { + // we are reading from single shard with multiple replicas but didn't transform replicas + // into virtual shards with custom_key set + throw Exception(ErrorCodes::LOGICAL_ERROR, "Replicas weren't transformed into virtual shards"); + } + + additional_shard_filter_generator = + [&, custom_key_ast = std::move(custom_key_ast), shard_count = query_info.cluster->getShardCount()](uint64_t shard_num) -> ASTPtr + { + return getCustomKeyFilterForParallelReplica( + shard_count, shard_num - 1, custom_key_ast, settings.parallel_replicas_custom_key_filter_type, *this, local_context); + }; + } + } + + auto executeHybridSearch = [&](ASTPtr & query_ast, QueryPlan & query_plan_hybrid_search, UInt8 score_type) + { + Block header_query + = InterpreterSelectQuery(query_ast, local_context, SelectQueryOptions(processed_stage).analyze()).getSampleBlock(); + + const auto & modified_query_ast + = ClusterProxy::rewriteSelectQuery(local_context, query_ast, remote_database, remote_table, remote_table_function_ptr); + + ClusterProxy::SelectStreamFactory select_stream_factory + = ClusterProxy::SelectStreamFactory(header_query, snapshot_data.objects_by_shard, storage_snapshot, processed_stage); + + ClusterProxy::executeQuery( + query_plan_hybrid_search, + header_query, + processed_stage, + main_table, + remote_table_function_ptr, + select_stream_factory, + log, + modified_query_ast, + local_context, + query_info, + sharding_key_expr, + sharding_key_column_name, + query_info.cluster, + additional_shard_filter_generator); + + if (!query_plan_hybrid_search.isInitialized()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Distributed HybridSearch Pipeline is not initialized"); + + /// Add score_type column to mark the type of score (distance = 0, bm25 = 1) + { + ColumnWithTypeAndName column; + column.name = SCORE_TYPE_COLUMN.name; + column.type = SCORE_TYPE_COLUMN.type; + column.column = column.type->createColumnConst(0, Field(score_type)); + + auto adding_column_dag = ActionsDAG::makeAddingColumnActions(std::move(column)); + auto transform_step + = std::make_unique(query_plan_hybrid_search.getCurrentDataStream(), std::move(adding_column_dag)); + transform_step->setStepDescription("Add distributed hybrid search score_type column"); + 
query_plan_hybrid_search.addStep(std::move(transform_step)); + } + + /// Rename score column name: distance_func -> HybridSearch_func, textsearch_func -> HybridSearch_func + { + auto columns_with_type_and_name = query_plan_hybrid_search.getCurrentDataStream().header.getColumnsWithTypeAndName(); + for (auto & column : columns_with_type_and_name) + { + if (score_type == 0 && column.name == "distance_func") + column.name = "HybridSearch_func"; + else if (score_type == 1 && column.name == "textsearch_func") + column.name = "HybridSearch_func"; + } + + auto actions_dag = ActionsDAG::makeConvertingActions( + query_plan_hybrid_search.getCurrentDataStream().header.getColumnsWithTypeAndName(), + columns_with_type_and_name, + ActionsDAG::MatchColumnsMode::Position); + + auto rename_step = std::make_unique(query_plan_hybrid_search.getCurrentDataStream(), std::move(actions_dag)); + rename_step->setStepDescription("Distributed hybrid search rename score column name"); + query_plan_hybrid_search.addStep(std::move(rename_step)); + } + + /// Convert column name to the same name as the original query + { + auto actions_dag = ActionsDAG::makeConvertingActions( + query_plan_hybrid_search.getCurrentDataStream().header.getColumnsWithTypeAndName(), + header.getColumnsWithTypeAndName(), + ActionsDAG::MatchColumnsMode::Name); + + auto converting_step + = std::make_unique(query_plan_hybrid_search.getCurrentDataStream(), std::move(actions_dag)); + converting_step->setStepDescription("Distributed hybrid search convert column name"); + query_plan_hybrid_search.addStep(std::move(converting_step)); + } + }; + + QueryPlan query_plan_vector_search, query_plan_text_search; + executeHybridSearch(query_ast_vector_search, query_plan_vector_search, 0); + executeHybridSearch(query_ast_text_search, query_plan_text_search, 1); + + DataStreams input_streams{query_plan_vector_search.getCurrentDataStream(), query_plan_text_search.getCurrentDataStream()}; + + std::vector> plans; + plans.emplace_back(std::make_unique(std::move(query_plan_vector_search))); + plans.emplace_back(std::make_unique(std::move(query_plan_text_search))); + + auto union_step = std::make_unique(std::move(input_streams)); + query_plan.unitePlans(std::move(union_step), std::move(plans)); + + /// This is a bug, it is possible only when there is no shards to query, and this is handled earlier. 
+ if (!query_plan.isInitialized()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Pipeline is not initialized"); +} + SinkToStoragePtr StorageDistributed::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context) { diff --git a/src/Storages/StorageDistributed.h b/src/Storages/StorageDistributed.h index d0aa1fcd01d..4d2bc65afcc 100644 --- a/src/Storages/StorageDistributed.h +++ b/src/Storages/StorageDistributed.h @@ -116,6 +116,13 @@ class StorageDistributed final : public IStorage, WithContext size_t /*max_block_size*/, size_t /*num_streams*/) override; + void readHybridSearch( + QueryPlan & query_plan, + const StorageSnapshotPtr & storage_snapshot, + SelectQueryInfo & query_info, + ContextPtr context, + QueryProcessingStage::Enum processed_stage); + bool supportsParallelInsert() const override { return true; } std::optional totalBytes(const Settings &) const override; diff --git a/src/Storages/StorageSnapshot.cpp b/src/Storages/StorageSnapshot.cpp index 0526b8000d1..56dfbf64b58 100644 --- a/src/Storages/StorageSnapshot.cpp +++ b/src/Storages/StorageSnapshot.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace DB { @@ -82,6 +83,10 @@ NamesAndTypesList StorageSnapshot::getColumnsByNames(const GetColumnsOptions & o { res.emplace_back(name, std::make_shared()); } + else if (name == SCORE_TYPE_COLUMN.name) + { + res.emplace_back(SCORE_TYPE_COLUMN); + } else if (isBatchDistance(name)) { auto id_type = std::make_shared(); @@ -175,6 +180,10 @@ Block StorageSnapshot::getSampleBlockForColumns(const Names & column_names, cons auto type = std::make_shared(); res.insert({type->createColumn(), type, column_name}); } + else if (column_name == SCORE_TYPE_COLUMN.name) + { + res.insert({SCORE_TYPE_COLUMN.type->createColumn(), SCORE_TYPE_COLUMN.type, SCORE_TYPE_COLUMN.name}); + } else if (isBatchDistance(column_name)) { auto id_type = std::make_shared(); diff --git a/src/VectorIndex/Common/BM25InfoInDataParts.h b/src/VectorIndex/Common/BM25InfoInDataParts.h index 4af1931672b..fc32cbd533b 100644 --- a/src/VectorIndex/Common/BM25InfoInDataParts.h +++ b/src/VectorIndex/Common/BM25InfoInDataParts.h @@ -28,10 +28,10 @@ namespace DB #if USE_TANTIVY_SEARCH -using RustVecDocWithFreq = rust::cxxbridge1::Vec; +using RustVecDocWithFreq = rust::cxxbridge1::Vec; /// Support tantivy index on multiple text columns -using VecTextColumnTokenNums = rust::cxxbridge1::Vec; +using VecTextColumnTokenNums = rust::cxxbridge1::Vec; struct BM25InfoInDataPart { diff --git a/src/VectorIndex/Interpreters/GetHybridSearchVisitor.h b/src/VectorIndex/Interpreters/GetHybridSearchVisitor.h index 3036a962428..38dea04a5e8 100644 --- a/src/VectorIndex/Interpreters/GetHybridSearchVisitor.h +++ b/src/VectorIndex/Interpreters/GetHybridSearchVisitor.h @@ -47,6 +47,10 @@ class GetHybridSearchMatcher std::vector vector_scan_funcs; std::vector text_search_func; std::vector hybrid_search_func; + + /// Save all vector scan functions including duplicated + /// Need to set flag in ASTFunction for multiple distances cases + std::vector all_multiple_vector_scan_funcs; }; static bool needChildVisit(const ASTPtr & node, const ASTPtr & child) @@ -84,11 +88,18 @@ class GetHybridSearchMatcher if (isVectorScanFunc(node.name)) { auto full_name = getFullName(node); - if (data.uniq_names.count(full_name)) - return; if (data.assert_no_vector_scan) throw Exception(ErrorCodes::ILLEGAL_VECTOR_SCAN, "Vector Scan function {} is found {} in query", full_name, String(data.assert_no_vector_scan)); + + /// Save all existing 
distance funcs + if (isDistance(node.name)) + data.all_multiple_vector_scan_funcs.push_back(&node); + + /// Save duplicated distance functions once + if (data.uniq_names.count(full_name)) + return; + data.vector_scan_funcs.push_back(&node); data.uniq_names.insert(full_name); } diff --git a/src/VectorIndex/Processors/FusionSortingStep.cpp b/src/VectorIndex/Processors/FusionSortingStep.cpp new file mode 100644 index 00000000000..256a094d546 --- /dev/null +++ b/src/VectorIndex/Processors/FusionSortingStep.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (2024) MOQI SINGAPORE PTE. LTD. and/or its affiliates + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace CurrentMetrics +{ +extern const Metric TemporaryFilesForSort; +} + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +static ITransformingStep::Traits getTraits(size_t limit) +{ + return ITransformingStep::Traits{ + { + .returns_single_stream = true, + .preserves_number_of_streams = false, + .preserves_sorting = false, + }, + { + .preserves_number_of_rows = limit == 0, + }}; +} + +FusionSortingStep::FusionSortingStep( + const DataStream & input_stream, + SortDescription sort_description_, + UInt64 limit_, + UInt64 num_candidates_, + String fusion_type_, + UInt64 fusion_k_, + Float32 fusion_weight_, + Int8 distance_score_order_direction_) + : ITransformingStep(input_stream, input_stream.header, getTraits(limit_)) + , result_description(std::move(sort_description_)) + , limit(limit_) + , num_candidates(num_candidates_) + , fusion_type(fusion_type_) + , fusion_k(fusion_k_) + , fusion_weight(fusion_weight_) + , distance_score_order_direction(distance_score_order_direction_) +{ + output_stream->sort_description = result_description; + output_stream->sort_scope = DataStream::SortScope::Global; +} + +void FusionSortingStep::updateOutputStream() +{ + output_stream = createOutputStream(input_streams.front(), input_streams.front().header, getDataStreamTraits()); + output_stream->sort_description = result_description; + output_stream->sort_scope = DataStream::SortScope::Global; +} + + +void FusionSortingStep::transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) +{ + auto fusion_transform = std::make_shared( + pipeline.getHeader(), num_candidates, fusion_type, fusion_k, fusion_weight, distance_score_order_direction); + pipeline.addTransform(std::move(fusion_transform)); + + auto sort_transform = std::make_shared(pipeline.getHeader(), result_description); + pipeline.addTransform(std::move(sort_transform)); +} + +void FusionSortingStep::describeActions(FormatSettings & settings) const +{ + String prefix(settings.offset, ' '); + + settings.out << prefix << "HybridSearch Fusion description: " << fusion_type; + settings.out << '\n'; + + if (limit) + settings.out << prefix << "Limit " << limit << '\n'; +} + +void 
FusionSortingStep::describeActions(JSONBuilder::JSONMap & map) const +{ + map.add("HybridSearch Fusion Description", fusion_type); + + if (limit) + map.add("Limit", limit); +} + +} diff --git a/src/VectorIndex/Processors/FusionSortingStep.h b/src/VectorIndex/Processors/FusionSortingStep.h new file mode 100644 index 00000000000..2b4194df08e --- /dev/null +++ b/src/VectorIndex/Processors/FusionSortingStep.h @@ -0,0 +1,61 @@ +/* + * Copyright (2024) MOQI SINGAPORE PTE. LTD. and/or its affiliates + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +namespace DB +{ + +/// Fusing and sorting data stream with distance or bm25 +class FusionSortingStep : public ITransformingStep +{ +public: + FusionSortingStep( + const DataStream & input_stream, + SortDescription sort_description_, + UInt64 limit_, + UInt64 num_candidates_, + String fusion_type_, + UInt64 fusion_k_, + Float32 fusion_weight_, + Int8 distance_score_order_direction_); + + String getName() const override { return "HybridSearchFusionSorting"; } + + void transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) override; + + void describeActions(JSONBuilder::JSONMap & map) const override; + void describeActions(FormatSettings & settings) const override; + +private: + void updateOutputStream() override; + + const SortDescription result_description; + UInt64 limit; + + UInt64 num_candidates = 0; + String fusion_type; + + UInt64 fusion_k; + Float32 fusion_weight; + + Int8 distance_score_order_direction; +}; + +} diff --git a/src/VectorIndex/Processors/HybridSearchFusionTransform.cpp b/src/VectorIndex/Processors/HybridSearchFusionTransform.cpp new file mode 100644 index 00000000000..6bdd0dbb635 --- /dev/null +++ b/src/VectorIndex/Processors/HybridSearchFusionTransform.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (2024) MOQI SINGAPORE PTE. LTD. and/or its affiliates + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +namespace DB +{ + +Chunk HybridSearchFusionTransform::generate() +{ + if (chunks.empty()) + return {}; + + Chunk merged_chunk = std::move(chunks.front()); + chunks.pop(); + while (!chunks.empty()) + { + merged_chunk.append(std::move(chunks.front())); + chunks.pop(); + } + + auto merged_chunk_columns = merged_chunk.getColumns(); + auto total_rows = merged_chunk.getNumRows(); + + /// Row range is stored in the following format: {start_index, length} + std::pair distance_row_range = {0, 0}; + std::pair bm25_row_range = {0, 0}; + + auto & merged_score_column = assert_cast(*merged_chunk_columns[score_column_pos]); + auto & merged_score_type_column = assert_cast(*merged_chunk_columns[score_type_column_pos]); + + auto & merged_shard_num_column = assert_cast(*merged_chunk_columns[fusion_shard_num_pos]); + auto & merged_part_index_column = assert_cast(*merged_chunk_columns[fusion_part_index_pos]); + auto & merged_part_offset_column = assert_cast(*merged_chunk_columns[fusion_part_offset_pos]); + + /// Get the distance and bm25 score range + { + for (size_t row = 0; row < total_rows; ++row) + { + if (merged_score_type_column.get64(row) == 0) + bm25_row_range.first++; + else + break; + } + + bm25_row_range.second = std::min(total_rows - bm25_row_range.first, num_candidates); + + if (vector_scan_order_direction == -1) + { + distance_row_range.first = 0; + distance_row_range.second = std::min(bm25_row_range.first, num_candidates); + } + else if (vector_scan_order_direction == 1) + { + if (bm25_row_range.first >= num_candidates) + { + distance_row_range.first = bm25_row_range.first - num_candidates; + distance_row_range.second = num_candidates; + } + else if (bm25_row_range.first > 0) + { + distance_row_range.first = 0; + distance_row_range.second = bm25_row_range.first; + } + } + + LOG_DEBUG(log, "distance_row_range, start index: {}, count: {}", distance_row_range.first, distance_row_range.second); + LOG_DEBUG(log, "bm25_row_range, start index: {}, count: {}", bm25_row_range.first, bm25_row_range.second); + } + + /// Create distance and bm25 score dataset + ScoreWithPartIndexAndLabels bm25_score_dataset, distance_score_dataset; + for (size_t offset = 0; offset < distance_row_range.second; ++offset) + { + size_t row; + if (vector_scan_order_direction == -1) + row = distance_row_range.first + offset; + else + row = distance_row_range.first + distance_row_range.second - 1 - offset; + + distance_score_dataset.emplace_back( + merged_score_column.getFloat32(row), + merged_part_index_column.get64(row), + merged_part_offset_column.get64(row), + merged_shard_num_column.get64(row)); + + LOG_TRACE(log, "distance_score_dataset: {}", distance_score_dataset.back().dump()); + } + for (size_t offset = 0; offset < bm25_row_range.second; ++offset) + { + size_t row = bm25_row_range.first + offset; + bm25_score_dataset.emplace_back( + merged_score_column.getFloat32(row), + merged_part_index_column.get64(row), + merged_part_offset_column.get64(row), + merged_shard_num_column.get64(row)); + + LOG_TRACE(log, "bm25_score_dataset: {}", bm25_score_dataset.back().dump()); + } + + /// Calculate the fusion score + std::map, Float32> fusion_id_with_score; + if (fusion_type == HybridSearchFusionType::RSF) + { + RelativeScoreFusion( + fusion_id_with_score, distance_score_dataset, bm25_score_dataset, fusion_weight, vector_scan_order_direction, log); + } + else if (fusion_type == HybridSearchFusionType::RRF) + { + RankFusion(fusion_id_with_score, distance_score_dataset, bm25_score_dataset, fusion_k, log); + 
} + + auto fusion_result_columns = merged_chunk.cloneEmptyColumns(); + auto & result_score_column = assert_cast(*fusion_result_columns[score_column_pos]); + + auto & result_shard_num_column = assert_cast(*fusion_result_columns[fusion_shard_num_pos]); + auto & result_part_index_column = assert_cast(*fusion_result_columns[fusion_part_index_pos]); + auto & result_part_offset_column = assert_cast(*fusion_result_columns[fusion_part_offset_pos]); + + /// Copy bm25 score rows to fusion_result_columns + for (size_t position = 0; position < merged_chunk_columns.size(); ++position) + { + auto column = merged_chunk_columns[position]; + fusion_result_columns[position]->insertRangeFrom(*column, bm25_row_range.first, bm25_row_range.second); + } + + /// Replace the bm25 score with fusion score + for (size_t row = 0; row < bm25_row_range.second; ++row) + { + auto fusion_id = std::make_tuple( + static_cast(result_shard_num_column.get64(row)), + result_part_index_column.get64(row), + result_part_offset_column.get64(row)); + + result_score_column.getElement(row) = fusion_id_with_score[fusion_id]; + fusion_id_with_score.erase(fusion_id); + } + + /// Copy distance score rows to fusion_result_columns and replace the distance score with fusion score + for (size_t offset = 0; offset < distance_row_range.second; ++offset) + { + size_t row = distance_row_range.first + offset; + + auto fusion_id = std::make_tuple( + static_cast(merged_shard_num_column.get64(row)), + merged_part_index_column.get64(row), + merged_part_offset_column.get64(row)); + + if (fusion_id_with_score.find(fusion_id) != fusion_id_with_score.end()) + { + for (size_t position = 0; position < merged_chunk_columns.size(); ++position) + { + if (position == score_column_pos) + { + fusion_result_columns[position]->insert(fusion_id_with_score[fusion_id]); + } + else + { + fusion_result_columns[position]->insertFrom(*merged_chunk_columns[position], row); + } + } + } + } + + merged_chunk.setColumns(std::move(fusion_result_columns), result_score_column.size()); + return merged_chunk; +} + +} diff --git a/src/VectorIndex/Processors/HybridSearchFusionTransform.h b/src/VectorIndex/Processors/HybridSearchFusionTransform.h new file mode 100644 index 00000000000..1ea0f6e1b3e --- /dev/null +++ b/src/VectorIndex/Processors/HybridSearchFusionTransform.h @@ -0,0 +1,96 @@ +/* + * Copyright (2024) MOQI SINGAPORE PTE. LTD. and/or its affiliates + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +class HybridSearchFusionTransform final : public IAccumulatingTransform +{ +public: + enum HybridSearchFusionType + { + RSF, + RRF + }; + + String getName() const override { return "HybridSearchFusionTransform"; } + + explicit HybridSearchFusionTransform( + Block header, + UInt64 num_candidates_, + String fusion_type_, + UInt64 fusion_k_, + Float32 fusion_weight_, + Int8 vector_scan_order_direction_) + : IAccumulatingTransform(header, header) + , num_candidates(num_candidates_) + , fusion_k(fusion_k_) + , fusion_weight(fusion_weight_) + , vector_scan_order_direction(vector_scan_order_direction_) + { + score_column_pos = header.getPositionByName(HYBRID_SEARCH_SCORE_COLUMN_NAME); + score_type_column_pos = header.getPositionByName(SCORE_TYPE_COLUMN.name); + + fusion_shard_num_pos = header.getPositionByName("shardNum()"); + fusion_part_index_pos = header.getPositionByName("_part_index"); + fusion_part_offset_pos = header.getPositionByName("_part_offset"); + + if (isRelativeScoreFusion(fusion_type_)) + fusion_type = HybridSearchFusionType::RSF; + else if (isRankFusion(fusion_type_)) + fusion_type = HybridSearchFusionType::RRF; + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown HybridSearch fusion type: {}", fusion_type_); + } + + void consume(Chunk block) override { chunks.push(std::move(block)); } + + Chunk generate() override; + +private: + std::queue chunks; + + UInt64 num_candidates = 0; + HybridSearchFusionType fusion_type; + UInt64 fusion_k; + Float32 fusion_weight; + + /// Vector scan order direction in vector scan query + /// 1 - ascending, -1 - descending + Int8 vector_scan_order_direction; + + size_t score_column_pos; + size_t score_type_column_pos; + + /// Combine shard_num, part_index, part_offset into fusion_id + size_t fusion_shard_num_pos; + size_t fusion_part_index_pos; + size_t fusion_part_offset_pos; + + Poco::Logger * log = &Poco::Logger::get("HybridSearchFusionTransform"); +}; + +} diff --git a/src/VectorIndex/Processors/ReadWithHybridSearch.cpp b/src/VectorIndex/Processors/ReadWithHybridSearch.cpp index 05eb0cf8bc8..16ff1fcbe95 100644 --- a/src/VectorIndex/Processors/ReadWithHybridSearch.cpp +++ b/src/VectorIndex/Processors/ReadWithHybridSearch.cpp @@ -242,6 +242,19 @@ ReadWithHybridSearch::ReadWithHybridSearch( analyzed_result_ptr_, enable_parallel_reading) { + VectorScanInfoPtr vector_scan_info = nullptr; + if (query_info.vector_scan_info) + vector_scan_info = query_info.vector_scan_info; + else if (query_info.hybrid_search_info) + vector_scan_info = query_info.hybrid_search_info->vector_scan_info; + + if (vector_scan_info) + { + /// Mark for each vector scan description + size_t vector_scan_descs_size = vector_scan_info->vector_scan_descs.size(); + vec_support_two_stage_searches.resize(vector_scan_descs_size, false); + vec_num_reorders.resize(vector_scan_descs_size, 0); + } } void ReadWithHybridSearch::initializePipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) @@ -498,7 +511,7 @@ Pipe ReadWithHybridSearch::readFromParts( for (const auto & part : parts) { MergeTreeBaseSearchManagerPtr search_manager = nullptr; - search_manager = std::make_shared(metadata_for_reading, query_info.vector_scan_info, context); + search_manager = std::make_shared(metadata_for_reading, query_info.vector_scan_info, context, false); auto algorithm = std::make_unique( search_manager, @@ -547,7 +560,7 @@ ReadWithHybridSearch::HybridAnalysisResult 
ReadWithHybridSearch::selectTotalHybr parts_with_ranges, metadata_snapshot, query_info, - support_two_stage_search, + vec_support_two_stage_searches, #if USE_TANTIVY_SEARCH bm25_stats_in_table, #endif @@ -579,16 +592,60 @@ ReadWithHybridSearch::HybridAnalysisResult ReadWithHybridSearch::selectTotalHybr if(isFinal(query_info) && parts_with_vector_text_result.size() > 1) { LOG_DEBUG(log, "Perform final on search results from parts"); - performFinal(parts_with_vector_text_result, num_streams); + performFinal(parts_with_ranges, parts_with_vector_text_result, num_streams); } Poco::Logger * hybrid_log = &Poco::Logger::get(log_name); - /// Combine vector scan results from selected parts to get top-k result for vector scan. - ScoreWithPartIndexAndLabels vec_scan_topk_results; + std::unordered_map multiple_distances_topk_results_map; if (vector_scan_info) { - vec_scan_topk_results = MergeTreeBaseSearchManager::getTotalTopKVSResult(parts_with_vector_text_result, vector_scan_info, hybrid_log); + /// Check for each vector scan desc + size_t descs_size = vector_scan_info->vector_scan_descs.size(); + + std::vector distances_topk_results_vector; + distances_topk_results_vector.resize(descs_size); + + auto get_total_topk_on_single_col = [&](size_t desc_index) + { + /// Combine vector scan results from selected parts to get top-k result for vector scan. + ScoreWithPartIndexAndLabels vec_scan_topk_results; + + const auto & vector_scan_desc = vector_scan_info->vector_scan_descs[desc_index]; + vec_scan_topk_results = MergeTreeBaseSearchManager::getTotalTopKVSResult( + parts_with_vector_text_result, desc_index, vector_scan_desc, hybrid_log); + + /// Save the topk results for this vector scan + distances_topk_results_vector[desc_index] = vec_scan_topk_results; + }; + + size_t num_threads = std::min(num_streams, descs_size); + if (num_threads <= 1) + { + for (size_t desc_index = 0; desc_index < descs_size; ++desc_index) + get_total_topk_on_single_col(desc_index); + } + else + { + /// Parallel executing get total topk and possible two search stage + ThreadPool pool(CurrentMetrics::MergeTreeDataSelectHybridSearchThreads, CurrentMetrics::MergeTreeDataSelectHybridSearchThreadsActive, num_threads); + + for (size_t desc_index = 0; desc_index < descs_size; ++desc_index) + pool.scheduleOrThrowOnError([&, desc_index]() + { + get_total_topk_on_single_col(desc_index); + }); + + pool.wait(); + } + + /// Save vector scan results in to a map with result column name as key. + for (size_t i = 0; i < descs_size; ++i) + { + const auto & vector_scan_desc = vector_scan_info->vector_scan_descs[i]; + const String & result_column_name = vector_scan_desc.column_name; + multiple_distances_topk_results_map[result_column_name] = distances_topk_results_vector[i]; + } } /// Combine text search results from selected parts to get top-k result for text search. @@ -596,24 +653,46 @@ ReadWithHybridSearch::HybridAnalysisResult ReadWithHybridSearch::selectTotalHybr if (text_search_info) text_search_topk_results = MergeTreeBaseSearchManager::getTotalTopKTextResult(parts_with_vector_text_result, text_search_info, hybrid_log); - /// Do the fusion on the total top-k result of vector scan and text search from all selected parts, based on (part_index, label_id). 
- ScoreWithPartIndexAndLabels hybrid_topk_results; - if (hybrid) - hybrid_topk_results = MergeTreeHybridSearchManager::hybridSearch(vec_scan_topk_results, text_search_topk_results, query_info.hybrid_search_info, hybrid_log); + /// hybrid search or text search + if (text_search_info) + { + ScoreWithPartIndexAndLabels hybrid_topk_results; + + if (hybrid) + { + /// Only has one vector scan for hybrid search + ScoreWithPartIndexAndLabels vec_scan_topk_results; + if (multiple_distances_topk_results_map.size() == 1) + vec_scan_topk_results = multiple_distances_topk_results_map.begin()->second; + + /// Do the fusion on the total top-k result of vector scan and text search from all selected parts, based on (part_index, label_id). + hybrid_topk_results = MergeTreeHybridSearchManager::hybridSearch(vec_scan_topk_results, text_search_topk_results, query_info.hybrid_search_info, hybrid_log); + } + else + { + /// Only simple text search, save result to hybrid_topk_results + hybrid_topk_results = text_search_topk_results; + } + + /// Filter parts with final top-k hybrid result, and save hybrid result with belonged part + hybrid_result.parts_with_hybrid_and_ranges = MergeTreeHybridSearchManager::FilterPartsWithHybridResults( + parts_with_ranges, hybrid_topk_results, context->getSettingsRef(), hybrid_log); + } else { - /// Only simple text search or vector scan, save result to hybrid_topk_results - hybrid_topk_results = text_search_topk_results.size() > 0 ? text_search_topk_results : vec_scan_topk_results; + /// Support multiple distance functions + /// Filter parts with final top-k vector scan results from multiple distance funcs + hybrid_result.parts_with_hybrid_and_ranges = MergeTreeVSManager::FilterPartsWithManyVSResults( + parts_with_ranges, multiple_distances_topk_results_map, context->getSettingsRef(), hybrid_log); } - /// Filter parts with final top-k hybrid result, and save hybrid result with belonged part - hybrid_result.parts_with_hybrid_and_ranges = MergeTreeHybridSearchManager::FilterPartsWithHybridResults( - parts_with_vector_text_result, hybrid_topk_results, context->getSettingsRef(), hybrid_log); - return hybrid_result; } -void ReadWithHybridSearch::performFinal(VectorAndTextResultInDataParts & parts_with_vector_text_result, size_t num_streams) const +void ReadWithHybridSearch::performFinal( + const RangesInDataParts & parts_with_ranges, + VectorAndTextResultInDataParts & parts_with_vector_text_result, + size_t num_streams) const { OpenTelemetry::SpanHolder span("ReadWithHybridSearch::performFinal()"); const auto & settings = context->getSettingsRef(); @@ -629,8 +708,7 @@ void ReadWithHybridSearch::performFinal(VectorAndTextResultInDataParts & parts_w auto process_part = [&](size_t part_index) { auto & part_with_mix_results = parts_with_vector_text_result[part_index]; - const auto & part_with_ranges = part_with_mix_results.part_with_ranges; - String part_name = part_with_ranges.data_part->name; + String part_name = part_with_mix_results.data_part->name; /// Get all labels from vector scan result and/or text result auto labels_set = MergeTreeBaseSearchManager::getLabelsInSearchResults(part_with_mix_results, log); @@ -638,6 +716,7 @@ void ReadWithHybridSearch::performFinal(VectorAndTextResultInDataParts & parts_w return; /// Use labels in search result to filter mark ranges of part + const auto & part_with_ranges = parts_with_ranges[part_with_mix_results.part_index]; RangesInDataPart result_ranges(part_with_ranges); filterMarkRangesByLabels(part_with_ranges.data_part, settings, labels_set, 
result_ranges.ranges); @@ -803,8 +882,7 @@ void ReadWithHybridSearch::performFinal(VectorAndTextResultInDataParts & parts_w /// For a result in a part, remove it if not exists in final results of this part for (auto & part_with_mix_results : parts_with_vector_text_result) { - const auto & part_with_ranges = part_with_mix_results.part_with_ranges; - String part_name = part_with_ranges.data_part->name; + String part_name = part_with_mix_results.data_part->name; if (final_part_labels_map.contains(part_name) && !final_part_labels_map[part_name].empty()) { @@ -814,15 +892,15 @@ void ReadWithHybridSearch::performFinal(VectorAndTextResultInDataParts & parts_w else { /// part not exists in final results - part_with_mix_results.vector_scan_result = nullptr; part_with_mix_results.text_search_result = nullptr; + part_with_mix_results.vector_scan_results.clear(); } } } VectorAndTextResultInDataParts ReadWithHybridSearch::selectPartsBySecondStageVectorIndex( const VectorAndTextResultInDataParts & parts_with_candidates, - const VectorScanInfoPtr & vec_scan_info, + const VSDescription & vector_scan_desc, size_t num_streams) const { OpenTelemetry::SpanHolder span3("ReadWithHybridSearch::selectPartsBySecondStageVectorIndex()"); @@ -832,13 +910,13 @@ VectorAndTextResultInDataParts ReadWithHybridSearch::selectPartsBySecondStageVec /// Execute second stage vector scan in this part. auto process_part = [&](size_t part_index) { - auto & part_with_ranges_candidates = parts_with_candidates[part_index]; - auto & part_with_ranges = part_with_ranges_candidates.part_with_ranges; - auto & data_part = part_with_ranges.data_part; + auto & part_with_candidates = parts_with_candidates[part_index]; + auto & data_part = part_with_candidates.data_part; - VectorAndTextResultInDataPart vector_result(part_with_ranges); + VectorAndTextResultInDataPart vector_result(part_with_candidates.part_index, data_part); - vector_result.vector_scan_result = MergeTreeVSManager::executeSecondStageVectorScan(data_part, vec_scan_info, part_with_ranges_candidates.vector_scan_result); + auto two_stage_vector_scan_result = MergeTreeVSManager::executeSecondStageVectorScan(data_part, vector_scan_desc, part_with_candidates.vector_scan_results[0]); + vector_result.vector_scan_results.emplace_back(two_stage_vector_scan_result); parts_with_vector_result[part_index] = std::move(vector_result); }; @@ -914,22 +992,33 @@ Pipe ReadWithHybridSearch::readFromParts( for (const auto & part_with_hybrid : parts_with_hybrid_ranges) { /// Already have search result for data part, save it to search_manager. 
- if (!part_with_hybrid.search_result || !part_with_hybrid.search_result->computed) - continue; + /// Support multiple distance functions, check multiple_vector_scan_results + if (query_info.vector_scan_info) + { + if (part_with_hybrid.multiple_vector_scan_results.empty()) + continue; + } + else + { + /// Check search_result for text and hybrid search + if (!part_with_hybrid.search_result || !part_with_hybrid.search_result->computed) + continue; + } MergeTreeBaseSearchManagerPtr search_manager = nullptr; + auto & part_with_ranges = part_with_hybrid.part_with_ranges; if (query_info.hybrid_search_info) search_manager = std::make_shared(part_with_hybrid.search_result, query_info.hybrid_search_info); else if (query_info.vector_scan_info) - search_manager = std::make_shared(part_with_hybrid.search_result, query_info.vector_scan_info); + search_manager = std::make_shared(part_with_hybrid.multiple_vector_scan_results, query_info.vector_scan_info); else if (query_info.text_search_info) search_manager = std::make_shared(part_with_hybrid.search_result, query_info.text_search_info); if (!search_manager) { /// Should not happen - LOG_WARNING(log, "Failed to initialize search manager for part {}", part_with_hybrid.data_part->name); + LOG_WARNING(log, "Failed to initialize search manager for part {}", part_with_ranges.data_part->name); continue; } @@ -939,13 +1028,13 @@ Pipe ReadWithHybridSearch::readFromParts( requested_num_streams, data, storage_snapshot, - part_with_hybrid.data_part, - part_with_hybrid.alter_conversions, + part_with_ranges.data_part, + part_with_ranges.alter_conversions, max_block_size, preferred_block_size_bytes, preferred_max_column_in_block_size_bytes, required_columns, - part_with_hybrid.ranges, + part_with_ranges.ranges, use_uncompressed_cache, prewhere_info, actions_settings, diff --git a/src/VectorIndex/Processors/ReadWithHybridSearch.h b/src/VectorIndex/Processors/ReadWithHybridSearch.h index d653929b6eb..66bd1b2a2ba 100644 --- a/src/VectorIndex/Processors/ReadWithHybridSearch.h +++ b/src/VectorIndex/Processors/ReadWithHybridSearch.h @@ -65,8 +65,10 @@ class ReadWithHybridSearch final : public ReadFromMergeTree private: - bool support_two_stage_search = false; /// True if two stage search is used. 
- [[maybe_unused]] UInt64 num_reorder = 0; /// number of candidates for first stage search + /// Support multiple distance functions + /// The sizes of the following two vectors are equal to the number of vector scan descriptions + std::vector<bool> vec_support_two_stage_searches; /// True if two stage search is supported + [[maybe_unused]] std::vector<UInt64> vec_num_reorders; /// number of candidates for first stage search ReadWithHybridSearch::HybridAnalysisResult getHybridSearchResult(const RangesInDataParts & parts) const; @@ -79,7 +81,7 @@ class ReadWithHybridSearch final : public ReadFromMergeTree /// Get accurate distance value for candidates by second stage vector index in belonged part VectorAndTextResultInDataParts selectPartsBySecondStageVectorIndex( const VectorAndTextResultInDataParts & parts_with_candidates, - const VectorScanInfoPtr & vec_scan_info, + const VSDescription & vector_scan_desc, size_t num_streams) const; Pipe readFromParts( @@ -95,10 +97,13 @@ class ReadWithHybridSearch final : public ReadFromMergeTree #if USE_TANTIVY_SEARCH void getStatisticForTextSearch(); - Statistics bm25_stats_in_table; /// total bm25 info from all parts in a table + TANTIVY::Statistics bm25_stats_in_table; /// total bm25 info from all parts in a table #endif - void performFinal(VectorAndTextResultInDataParts & parts_with_vector_text_result, size_t num_streams) const; + void performFinal( + const RangesInDataParts & parts_with_ranges, + VectorAndTextResultInDataParts & parts_with_vector_text_result, + size_t num_streams) const; }; } diff --git a/src/VectorIndex/Storages/HybridSearchResult.h b/src/VectorIndex/Storages/HybridSearchResult.h index eff24cabf9d..2bd244c95ae 100644 --- a/src/VectorIndex/Storages/HybridSearchResult.h +++ b/src/VectorIndex/Storages/HybridSearchResult.h @@ -30,58 +30,68 @@ struct CommonSearchResult { bool computed = false; MutableColumns result_columns; - std::vector<bool> was_result_processed; /// Mark if the result was processed or not.
+ String name; /// vector scan function (distance) result column name }; using CommonSearchResultPtr = std::shared_ptr; using VectorScanResultPtr = std::shared_ptr; +using ManyVectorScanResults = std::vector; + using TextSearchResultPtr = std::shared_ptr; using HybridSearchResultPtr = std::shared_ptr; /// Extend RangesInDataPart to include search result for hybrid, vector scan or text search struct SearchResultAndRangesInDataPart { - DataPartPtr data_part; - AlterConversionsPtr alter_conversions; - size_t part_index_in_query; - MarkRanges ranges; - CommonSearchResultPtr search_result; + RangesInDataPart part_with_ranges; + + /// Valid for text, hybrid search + CommonSearchResultPtr search_result = nullptr; + + /// Support multiple distance functions + /// Valid for vector scan + ManyVectorScanResults multiple_vector_scan_results = {}; SearchResultAndRangesInDataPart() = default; SearchResultAndRangesInDataPart( - const DataPartPtr & data_part_, - const AlterConversionsPtr & alter_conversions_, - const size_t part_index_in_query_, - const MarkRanges & ranges_ = MarkRanges{}, + const RangesInDataPart & part_with_ranges_, const CommonSearchResultPtr & search_result_ = nullptr) - : data_part{data_part_} - , alter_conversions{alter_conversions_} - , part_index_in_query{part_index_in_query_} - , ranges{ranges_} + : part_with_ranges{part_with_ranges_} , search_result{search_result_} {} + + SearchResultAndRangesInDataPart( + const RangesInDataPart & part_with_ranges_, + const ManyVectorScanResults & multiple_vector_scan_results_ = {}) + : part_with_ranges{part_with_ranges_} + , multiple_vector_scan_results{multiple_vector_scan_results_} + {} }; using SearchResultAndRangesInDataParts = std::vector; +/// Internal structure of intermediate results /// Save vector scan and/or text search result for a part struct VectorAndTextResultInDataPart { - RangesInDataPart part_with_ranges; - VectorScanResultPtr vector_scan_result; - TextSearchResultPtr text_search_result; + /// Other data part info can be accessed by part_index from parts_with_ranges + size_t part_index; + DataPartPtr data_part; + + /// Valid for text and hybrid search + TextSearchResultPtr text_search_result = nullptr; + + /// Support multiple distance functions + /// Use the first element for hybrid search and second stage of two-stage vector search + ManyVectorScanResults vector_scan_results = {}; VectorAndTextResultInDataPart() = default; - VectorAndTextResultInDataPart( - const RangesInDataPart & part_with_ranges_, - const VectorScanResultPtr & vector_scan_result_ = nullptr, - const TextSearchResultPtr & text_search_result_ = nullptr) - : part_with_ranges{part_with_ranges_} - , vector_scan_result{vector_scan_result_} - , text_search_result{text_search_result_} + VectorAndTextResultInDataPart(const size_t part_index_, const DataPartPtr & data_part_) + : part_index{part_index_} + , data_part{data_part_} {} }; @@ -91,14 +101,28 @@ using VectorAndTextResultInDataParts = std::vector; diff --git a/src/VectorIndex/Storages/MergeTreeBaseSearchManager.cpp b/src/VectorIndex/Storages/MergeTreeBaseSearchManager.cpp index 164a8d4474f..1837e6c0e2c 100644 --- a/src/VectorIndex/Storages/MergeTreeBaseSearchManager.cpp +++ b/src/VectorIndex/Storages/MergeTreeBaseSearchManager.cpp @@ -25,7 +25,6 @@ void MergeTreeBaseSearchManager::mergeSearchResultImpl( size_t & read_rows, const ReadRanges & read_ranges, CommonSearchResultPtr tmp_result, - const Search::DenseBitmapPtr filter, const ColumnUInt64 * part_offset) { Poco::Logger * log = 
&Poco::Logger::get("MergeTreeBaseSearchManager"); @@ -40,8 +39,8 @@ void MergeTreeBaseSearchManager::mergeSearchResultImpl( } /// Initialize was_result_processed - if (tmp_result->was_result_processed.size() == 0) - tmp_result->was_result_processed.assign(label_column->size(), false); + if (was_result_processed.size() == 0) + was_result_processed.assign(label_column->size(), false); auto final_distance_column = DataTypeFloat32().createColumn(); @@ -53,149 +52,103 @@ void MergeTreeBaseSearchManager::mergeSearchResultImpl( final_result.emplace_back(col->cloneEmpty()); } - if (filter) + if (part_offset == nullptr) { - size_t current_column_pos = 0; - int range_index = 0; - size_t start_pos = read_ranges[range_index].start_row; - size_t offset = 0; - for (size_t i = 0; i < filter->get_size(); ++i) + LOG_DEBUG(log, "No part offset"); + size_t start_pos = 0; + size_t end_pos = 0; + size_t prev_row_num = 0; + + for (auto & read_range : read_ranges) { - if (offset >= read_ranges[range_index].row_num) - { - ++range_index; - start_pos = read_ranges[range_index].start_row; - offset = 0; - } - if (filter->unsafe_test(i)) + start_pos = read_range.start_row; + end_pos = read_range.start_row + read_range.row_num; + for (size_t ind = 0; ind < label_column->size(); ++ind) { - /// for each vector search result, try to find if there is one with label equals to row id. - for (size_t ind = 0; ind < label_column->size(); ++ind) - { - /// Skip if this label has already processed - if (tmp_result->was_result_processed[ind]) - continue; + if (was_result_processed[ind]) + continue; - if (label_column->getUInt(ind) == start_pos + offset) + const UInt64 label_value = label_column->getUInt(ind); + if (label_value >= start_pos && label_value < end_pos) + { + const size_t index_of_arr = label_value - start_pos + prev_row_num; + for (size_t i = 0; i < final_result.size(); ++i) { - /// LOG_DEBUG(log, "merge result: ind: {}, current_column_pos: {}, filter_id: {}", ind, current_column_pos, i + start_offset); - /// for each result column - for (size_t col = 0; col < final_result.size(); ++col) - { - Field field; - pre_result[col]->get(current_column_pos, field); - final_result[col]->insert(field); - } - final_distance_column->insert(distance_column->getFloat32(ind)); - - tmp_result->was_result_processed[ind] = true; + Field field; + pre_result[i]->get(index_of_arr, field); + final_result[i]->insert(field); } + + final_distance_column->insert(distance_column->getFloat32(ind)); + + was_result_processed[ind] = true; } - ++current_column_pos; } - ++offset; + prev_row_num += read_range.row_num; } } - else + else if (part_offset->size() > 0) { - LOG_DEBUG(log, "No filter statement"); - if (part_offset == nullptr) - { - size_t start_pos = 0; - size_t end_pos = 0; - size_t prev_row_num = 0; + LOG_DEBUG(log, "Get part offset"); - for (auto & read_range : read_ranges) - { - start_pos = read_range.start_row; - end_pos = read_range.start_row + read_range.row_num; - for (size_t ind = 0; ind < label_column->size(); ++ind) - { - if (tmp_result->was_result_processed[ind]) - continue; - - const UInt64 label_value = label_column->getUInt(ind); - if (label_value >= start_pos && label_value < end_pos) - { - const size_t index_of_arr = label_value - start_pos + prev_row_num; - for (size_t i = 0; i < final_result.size(); ++i) - { - Field field; - pre_result[i]->get(index_of_arr, field); - final_result[i]->insert(field); - } + /// When lightweight delete applied, the rowid in the label column cannot be used as index of pre_result. 
+ /// Match the rowid in the value of label col and the value of part_offset to find the correct index. + const ColumnUInt64::Container & offset_raw_value = part_offset->getData(); + size_t part_offset_size = part_offset->size(); - final_distance_column->insert(distance_column->getFloat32(ind)); + size_t start_pos = 0; + size_t end_pos = 0; - tmp_result->was_result_processed[ind] = true; - } - } - prev_row_num += read_range.row_num; - } - } - else if (part_offset->size() > 0) + for (auto & read_range : read_ranges) { - LOG_DEBUG(log, "Get part offset"); + start_pos = read_range.start_row; + end_pos = read_range.start_row + read_range.row_num; - /// When lightweight delete applied, the rowid in the label column cannot be used as index of pre_result. - /// Match the rowid in the value of label col and the value of part_offset to find the correct index. - const ColumnUInt64::Container & offset_raw_value = part_offset->getData(); - size_t part_offset_size = part_offset->size(); - - size_t start_pos = 0; - size_t end_pos = 0; - - for (auto & read_range : read_ranges) + for (size_t ind = 0; ind < label_column->size(); ++ind) { - start_pos = read_range.start_row; - end_pos = read_range.start_row + read_range.row_num; + if (was_result_processed[ind]) + continue; - for (size_t ind = 0; ind < label_column->size(); ++ind) - { - if (tmp_result->was_result_processed[ind]) - continue; + const UInt64 label_value = label_column->getUInt(ind); - const UInt64 label_value = label_column->getUInt(ind); + /// Check if label_value inside this read range + if (label_value < start_pos || (label_value >= end_pos)) + continue; - /// Check if label_value inside this read range - if (label_value < start_pos || (label_value >= end_pos)) - continue; + /// read range doesn't consider LWD, hence start_row and row_num in read range cannot be used in this case. + int low = 0; + int mid; + if (part_offset_size - 1 > static_cast(std::numeric_limits::max())) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of part_offset_size exceeds the limit of int data type"); + int high = static_cast(part_offset_size - 1); - /// read range doesn't consider LWD, hence start_row and row_num in read range cannot be used in this case. - int low = 0; - int mid; - if (part_offset_size - 1 > static_cast(std::numeric_limits::max())) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of part_offset_size exceeds the limit of int data type"); - int high = static_cast(part_offset_size - 1); + /// label_value (row id) = part_offset. + /// We can use binary search to quickly locate part_offset for current label. + while (low <= high) + { + mid = low + (high - low) / 2; - /// label_value (row id) = part_offset. - /// We can use binary search to quickly locate part_offset for current label. - while (low <= high) + if (label_value == offset_raw_value[mid]) { - mid = low + (high - low) / 2; - - if (label_value == offset_raw_value[mid]) + /// Use the index of part_offset to locate other columns in pre_result and fill final_result. + for (size_t i = 0; i < final_result.size(); ++i) { - /// Use the index of part_offset to locate other columns in pre_result and fill final_result. 
- for (size_t i = 0; i < final_result.size(); ++i) - { - Field field; - pre_result[i]->get(mid, field); - final_result[i]->insert(field); - } + Field field; + pre_result[i]->get(mid, field); + final_result[i]->insert(field); + } - final_distance_column->insert(distance_column->getFloat32(ind)); + final_distance_column->insert(distance_column->getFloat32(ind)); - tmp_result->was_result_processed[ind] = true; + was_result_processed[ind] = true; - /// break from binary search loop - break; - } - else if (label_value > offset_raw_value[mid]) - low = mid + 1; - else - high = mid - 1; + /// break from binary search loop + break; } + else if (label_value > offset_raw_value[mid]) + low = mid + 1; + else + high = mid - 1; } } } @@ -212,15 +165,16 @@ void MergeTreeBaseSearchManager::mergeSearchResultImpl( ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopKVSResult( const VectorAndTextResultInDataParts & vector_results, - const VectorScanInfoPtr & vector_scan_info, + const size_t vec_res_index, + const VSDescription & vector_scan_desc, Poco::Logger * log) { - const auto vec_scan_desc = vector_scan_info->vector_scan_descs[0]; + bool desc_direction = vector_scan_desc.direction == -1; + int top_k = vector_scan_desc.topk > 0 ? vector_scan_desc.topk : VectorIndex::DEFAULT_TOPK; - bool desc_direction = vec_scan_desc.direction == -1; - int top_k = vec_scan_desc.topk > 0 ? vec_scan_desc.topk : VectorIndex::DEFAULT_TOPK; + LOG_TRACE(log, "Total top k vector scan results for {} with size {}", vector_scan_desc.column_name, top_k); - return getTotalTopSearchResultImpl(vector_results, static_cast(top_k), desc_direction, log, true); + return getTotalTopSearchResultImpl(vector_results, static_cast(top_k), desc_direction, log, true, vec_res_index); } ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopKTextResult( @@ -229,20 +183,25 @@ ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopKTextResult( Poco::Logger * log) { int top_k = text_info->topk; + + LOG_TRACE(log, "Total top k text results with size {}", top_k); + return getTotalTopSearchResultImpl(text_results, static_cast(top_k), true, log, false); } ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalCandidateVSResult( const VectorAndTextResultInDataParts & parts_with_vector_text_result, - const VectorScanInfoPtr & vector_scan_info, + const size_t vec_res_index, + const VSDescription & vector_scan_desc, const UInt64 & num_reorder, Poco::Logger * log) { - const auto vec_scan_desc = vector_scan_info->vector_scan_descs[0]; - bool desc_direction = vec_scan_desc.direction == -1; + bool desc_direction = vector_scan_desc.direction == -1; + + LOG_TRACE(log, "Total candidate vector scan results for {} with size {}", vector_scan_desc.column_name, num_reorder); /// Get top num_reorder candidates: part index + label + score - return getTotalTopSearchResultImpl(parts_with_vector_text_result, num_reorder, desc_direction, log, true); + return getTotalTopSearchResultImpl(parts_with_vector_text_result, num_reorder, desc_direction, log, true, vec_res_index); } ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopSearchResultImpl( @@ -250,7 +209,8 @@ ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopSearchResultI const UInt64 & top_k, const bool & desc_direction, Poco::Logger * log, - const bool need_vector) + const bool need_vector, + const size_t vec_res_index) { String search_name = need_vector ? 
"vector scan" : "text search"; @@ -262,11 +222,15 @@ ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopSearchResultI /// part + top-k result in part CommonSearchResultPtr search_result; if (need_vector) - search_result = mix_results_in_part.vector_scan_result; + { + /// Support multiple distance functions + if (vec_res_index < mix_results_in_part.vector_scan_results.size()) + search_result = mix_results_in_part.vector_scan_results[vec_res_index]; + } else search_result = mix_results_in_part.text_search_result; - const auto & part_index = mix_results_in_part.part_with_ranges.part_index_in_query; + const auto & part_index = mix_results_in_part.part_index; if (search_result && search_result->computed) { @@ -306,6 +270,9 @@ ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopSearchResultI { for (auto rit = sorted_score_with_index_labels.rbegin(); rit != sorted_score_with_index_labels.rend(); ++rit) { + LOG_TRACE(log, "part_index={}, label_id={}, score={}", rit->second.part_index, + rit->second.label_id, rit->second.score); + result_score_with_index_labels.emplace_back(rit->second); count++; @@ -317,6 +284,9 @@ ScoreWithPartIndexAndLabels MergeTreeBaseSearchManager::getTotalTopSearchResultI { for (const auto & [_, score_with_index_label] : sorted_score_with_index_labels) { + LOG_TRACE(log, "part_index={}, label_id={}, score={}", score_with_index_label.part_index, + score_with_index_label.label_id, score_with_index_label.score); + result_score_with_index_labels.emplace_back(score_with_index_label); count++; @@ -335,15 +305,18 @@ std::set MergeTreeBaseSearchManager::getLabelsInSearchResults( OpenTelemetry::SpanHolder span("MergeTreeBaseSearchManager::getLabelsInSearchResults()"); std::set label_ids; - VectorScanResultPtr vector_result = mix_results.vector_scan_result; TextSearchResultPtr text_result = mix_results.text_search_result; - if (vector_result && vector_result->computed) - getLabelsInSearchResult(label_ids, vector_result, log); - if (text_result && text_result->computed) getLabelsInSearchResult(label_ids, text_result, log); + /// Support multiple distance functions + for (const auto & vector_scan_result : mix_results.vector_scan_results) + { + if (vector_scan_result && vector_scan_result->computed) + getLabelsInSearchResult(label_ids, vector_scan_result, log); + } + return label_ids; } @@ -377,23 +350,26 @@ void MergeTreeBaseSearchManager::filterSearchResultsByFinalLabels( std::set & label_ids, Poco::Logger * log) { - LOG_DEBUG(log, "filterSearchResultsByFinalLabels: part = {}", mix_results.part_with_ranges.data_part->name); + LOG_DEBUG(log, "filterSearchResultsByFinalLabels: part = {}", mix_results.data_part->name); if (label_ids.empty()) { - mix_results.vector_scan_result = nullptr; mix_results.text_search_result = nullptr; + mix_results.vector_scan_results.clear(); return; } - VectorScanResultPtr vector_result = mix_results.vector_scan_result; TextSearchResultPtr text_result = mix_results.text_search_result; - if (vector_result && vector_result->computed) - mix_results.vector_scan_result = filterSearchResultByFinalLabels(vector_result, label_ids, log); - if (text_result && text_result->computed) mix_results.text_search_result = filterSearchResultByFinalLabels(text_result, label_ids, log); + + /// Support multiple distance functions + for (auto & vector_scan_result : mix_results.vector_scan_results) + { + if (vector_scan_result && vector_scan_result->computed) + vector_scan_result = filterSearchResultByFinalLabels(vector_scan_result, label_ids, log); + } 
} CommonSearchResultPtr MergeTreeBaseSearchManager::filterSearchResultByFinalLabels( @@ -441,6 +417,9 @@ CommonSearchResultPtr MergeTreeBaseSearchManager::filterSearchResultByFinalLabel final_result->computed = true; final_result->result_columns[0] = std::move(res_label_column); final_result->result_columns[1] = std::move(res_score_column); + + /// Add column name + final_result->name = pre_search_result->name; } return final_result; diff --git a/src/VectorIndex/Storages/MergeTreeBaseSearchManager.h b/src/VectorIndex/Storages/MergeTreeBaseSearchManager.h index bbc85544775..916df81c5bf 100644 --- a/src/VectorIndex/Storages/MergeTreeBaseSearchManager.h +++ b/src/VectorIndex/Storages/MergeTreeBaseSearchManager.h @@ -38,11 +38,12 @@ class MergeTreeBaseSearchManager using ReadRanges = MergeTreeRangeReader::ReadResult::ReadRangesInfo; MergeTreeBaseSearchManager( - StorageMetadataPtr metadata_, ContextPtr context_, const String & function_col_name_) + StorageMetadataPtr metadata_, ContextPtr context_, const String & function_col_name_ = "") : metadata(metadata_) , context(context_) - , function_col_name(function_col_name_) { + if (!function_col_name_.empty()) + search_func_cols_names.emplace_back(function_col_name_); } virtual ~MergeTreeBaseSearchManager() = default; @@ -62,7 +63,6 @@ class MergeTreeBaseSearchManager Columns & /* pre_result */, size_t & /* read_rows */, const ReadRanges & /* read_ranges */, - const Search::DenseBitmapPtr /* filter */, const ColumnUInt64 * /* part_offset */) {} /// True if search result is present and computed flag is set to true. @@ -70,14 +70,16 @@ class MergeTreeBaseSearchManager virtual CommonSearchResultPtr getSearchResult() { return nullptr; } - String getFuncColumnName() { return function_col_name; } + /// Support multiple distance functions + Names getSearchFuncColumnNames() { return search_func_cols_names; } const Settings & getSettings() { return context->getSettingsRef(); } /// Get top-k vector scan result among all selected parts static ScoreWithPartIndexAndLabels getTotalTopKVSResult( const VectorAndTextResultInDataParts & vector_results, - const VectorScanInfoPtr & vector_scan_info, + const size_t vec_res_index, + const VSDescription & vector_scan_desc, Poco::Logger * log); static ScoreWithPartIndexAndLabels getTotalTopKTextResult( @@ -88,7 +90,8 @@ class MergeTreeBaseSearchManager /// Get num_reorder candidate vector result among all selected parts for two stage search static ScoreWithPartIndexAndLabels getTotalCandidateVSResult( const VectorAndTextResultInDataParts & parts_with_vector_text_result, - const VectorScanInfoPtr & vector_scan_info, + const size_t vec_res_index, + const VSDescription & vector_scan_desc, const UInt64 & num_reorder, Poco::Logger * log); @@ -105,7 +108,9 @@ class MergeTreeBaseSearchManager StorageMetadataPtr metadata; ContextPtr context; - String function_col_name; /// search function column name + Names search_func_cols_names; /// names of search function columns + + std::vector was_result_processed; /// Mark if the result was processed or not. 
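The header change above replaces the single `function_col_name` with a `Names search_func_cols_names` list, so one manager can describe several distance-function columns while single-function callers keep passing one (possibly empty) name. A rough sketch of that constructor pattern under simplified types; the class and column names below are made up for illustration:

```cpp
#include <iostream>
#include <string>
#include <vector>

using Names = std::vector<std::string>; // stand-in for DB::Names

// Base level: at most one function column name, appended only when non-empty.
class BaseSearchManagerSketch
{
public:
    explicit BaseSearchManagerSketch(const std::string & function_col_name = "")
    {
        if (!function_col_name.empty())
            search_func_cols_names.emplace_back(function_col_name);
    }

    Names getSearchFuncColumnNames() const { return search_func_cols_names; }

protected:
    Names search_func_cols_names;
};

// A manager handling several distance functions fills in one name per function.
class VectorScanManagerSketch : public BaseSearchManagerSketch
{
public:
    explicit VectorScanManagerSketch(const Names & distance_col_names)
    {
        search_func_cols_names = distance_col_names;
    }
};

int main()
{
    VectorScanManagerSketch mgr({"distance_0", "distance_1"}); // hypothetical column names
    for (const auto & name : mgr.getSearchFuncColumnNames())
        std::cout << name << '\n';
}
```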
    /// Merge search result with pre_result from read part
    /// Add score/distance of the row id to the corresponding row
@@ -114,7 +119,6 @@ class MergeTreeBaseSearchManager
         size_t & read_rows,
         const ReadRanges & read_ranges = ReadRanges(),
         CommonSearchResultPtr tmp_result = nullptr,
-        const Search::DenseBitmapPtr filter = nullptr,
         const ColumnUInt64 * part_offset = nullptr);
     /// Get top-k vector or text search result among all selected parts
@@ -125,7 +129,8 @@ class MergeTreeBaseSearchManager
         const UInt64 & top_k,
         const bool & desc_direction,
         Poco::Logger * log,
-        const bool need_vector);
+        const bool need_vector,
+        const size_t vec_res_index = 0);
     /// Get label_ids in search result and save in label_ids set.
     static void getLabelsInSearchResult(
diff --git a/src/VectorIndex/Storages/MergeTreeHybridSearchManager.cpp b/src/VectorIndex/Storages/MergeTreeHybridSearchManager.cpp
index 1a45d39ac3c..7d4332b2552 100644
--- a/src/VectorIndex/Storages/MergeTreeHybridSearchManager.cpp
+++ b/src/VectorIndex/Storages/MergeTreeHybridSearchManager.cpp
@@ -21,13 +21,14 @@
 #include
 #include
-
+#include
 #include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
@@ -35,6 +36,12 @@
 #include
+namespace CurrentMetrics
+{
+    extern const Metric MergeTreeDataSelectHybridSearchThreads;
+    extern const Metric MergeTreeDataSelectHybridSearchThreadsActive;
+}
+
 namespace DB
 {
@@ -107,8 +114,9 @@ ScoreWithPartIndexAndLabels MergeTreeHybridSearchManager::hybridSearch(
     /// Get fusion type from hybrid_info
     String fusion_type = hybrid_info->fusion_type;
-    /// Store result after fusion. (<part_index, label_id>, score)
-    std::map, Float32> part_index_labels_with_fusion_score;
+    /// Store result after fusion. (<shard_num, part_index, label_id>, score)
+    /// As for single-shard hybrid search, shard_num is always 0.
+    std::map, Float32> part_index_labels_with_fusion_score;
     /// Relative Score Fusion
     if (isRelativeScoreFusion(fusion_type))
@@ -127,7 +135,7 @@ ScoreWithPartIndexAndLabels MergeTreeHybridSearchManager::hybridSearch(
         /// Assume fusion_k is handled by ExpressionAnalyzer
         int fusion_k = hybrid_info->fusion_k <= 0 ? 60 : hybrid_info->fusion_k;
-        RankFusion(part_index_labels_with_fusion_score, vec_scan_result_with_part_index, text_search_result_with_part_index, fusion_k);
+        RankFusion(part_index_labels_with_fusion_score, vec_scan_result_with_part_index, text_search_result_with_part_index, fusion_k, log);
     }
     /// Sort hybrid search result based on fusion score and return top-k rows.
@@ -135,8 +143,14 @@ ScoreWithPartIndexAndLabels MergeTreeHybridSearchManager::hybridSearch(
     std::multimap, std::greater> sorted_fusion_scores_with_part_index_label;
     for (const auto & [part_index_label_id, fusion_score] : part_index_labels_with_fusion_score)
     {
-        LOG_TEST(log, "part_index={}, label_id={}, hybrid_score={}", part_index_label_id.first, part_index_label_id.second, fusion_score);
-        sorted_fusion_scores_with_part_index_label.emplace(fusion_score, part_index_label_id);
+        LOG_TEST(
+            log,
+            "part_index={}, label_id={}, hybrid_score={}",
+            std::get<1>(part_index_label_id),
+            std::get<2>(part_index_label_id),
+            fusion_score);
+
+        sorted_fusion_scores_with_part_index_label.emplace(fusion_score, std::make_pair(std::get<1>(part_index_label_id), std::get<2>(part_index_label_id)));
     }
     /// Save topk part indexes, label ids and fusion score into hybrid_result.
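The fusion map is now keyed by a `(shard_num, part_index, label_id)` tuple, and the top-k step re-keys it by fusion score through a descending `std::multimap`, dropping `shard_num` again for the single-shard case. A self-contained sketch of that sort-and-truncate pattern; the element types and sample scores are assumptions for illustration:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <tuple>
#include <utility>

int main()
{
    // Fusion scores keyed by (shard_num, part_index, label_id); shard_num stays 0 for single-shard search.
    using Key = std::tuple<uint32_t, size_t, uint64_t>;
    std::map<Key, float> fusion_scores;
    fusion_scores.emplace(Key{0, 0, 11}, 0.031f);
    fusion_scores.emplace(Key{0, 1, 42}, 0.047f);
    fusion_scores.emplace(Key{0, 2, 7}, 0.016f);

    // Re-key by score in descending order, keeping only (part_index, label_id).
    std::multimap<float, std::pair<size_t, uint64_t>, std::greater<float>> sorted;
    for (const auto & [key, score] : fusion_scores)
        sorted.emplace(score, std::make_pair(std::get<1>(key), std::get<2>(key)));

    // Emit only the top-k rows.
    const size_t top_k = 2;
    size_t count = 0;
    for (const auto & [score, part_label] : sorted)
    {
        if (count++ == top_k)
            break;
        std::cout << "part_index=" << part_label.first
                  << " label_id=" << part_label.second
                  << " score=" << score << '\n';
    }
}
```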
@@ -156,140 +170,8 @@ ScoreWithPartIndexAndLabels MergeTreeHybridSearchManager::hybridSearch( return hybrid_result; } -void MergeTreeHybridSearchManager::RelativeScoreFusion( - std::map, Float32> & part_index_labels_with_convex_score, - const ScoreWithPartIndexAndLabels & vec_scan_result_with_part_index, - const ScoreWithPartIndexAndLabels & text_search_result_with_part_index, - const float weight_of_text, - const int vector_scan_direction, - Poco::Logger * log) -{ - /// min-max normalization on text search score - std::vector norm_score; - norm_score.reserve(text_search_result_with_part_index.size()); - computeMinMaxNormScore(text_search_result_with_part_index, norm_score, log); - - LOG_TEST(log, "text bm25 scores:"); - /// final score = norm-BM25 * w + (1-w) * norm-distance - for (size_t idx = 0; idx < text_search_result_with_part_index.size(); idx++) - { - const auto & text_score_with_part_index = text_search_result_with_part_index[idx]; - auto part_index_label_id = std::make_pair(text_score_with_part_index.part_index, text_score_with_part_index.label_id); - - LOG_TEST(log, "part_index={}, label_id={}, origin_score={}, norm_score={}", - text_score_with_part_index.part_index, text_score_with_part_index.label_id, text_score_with_part_index.score, norm_score[idx]); - - /// label_ids from text search are unique - part_index_labels_with_convex_score[part_index_label_id] = norm_score[idx] * weight_of_text; - } - - /// min-max normalization on text search score - norm_score.clear(); - computeMinMaxNormScore(vec_scan_result_with_part_index, norm_score, log); - - LOG_TEST(log, "distance scores:"); - /// The Relative score fusion with distance score depends on the metric type. - for (size_t idx = 0; idx < vec_scan_result_with_part_index.size(); idx++) - { - const auto & vec_score_with_part_index = vec_scan_result_with_part_index[idx]; - auto part_index_label_id = std::make_pair(vec_score_with_part_index.part_index, vec_score_with_part_index.label_id); - - LOG_TEST(log, "part_index={}, label_id={}, origin_score={}, norm_score={}", - vec_score_with_part_index.part_index, vec_score_with_part_index.label_id, vec_score_with_part_index.score, norm_score[idx]); - - Float32 fusion_score = 0; - - /// 1 - ascending, -1 - descending - if (vector_scan_direction == -1) - fusion_score = norm_score[idx] * (1 - weight_of_text); - else - fusion_score = (1 - weight_of_text) * (1 - norm_score[idx]); - - /// Insert or update score for label_id - part_index_labels_with_convex_score[part_index_label_id] += fusion_score; - } -} - -void MergeTreeHybridSearchManager::computeMinMaxNormScore( - const ScoreWithPartIndexAndLabels & search_result_with_part_index, - std::vector & norm_score_vec, - Poco::Logger * log) -{ - const auto result_size = search_result_with_part_index.size(); - if (result_size == 0) - { - LOG_DEBUG(log, "search result is empty"); - return; - } - - /// Here assume the scores in score column are ordered. - /// Thus the min score and max score are the first and last. 
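The removed helpers spell out relative score fusion: min-max normalize the BM25 and distance scores, then combine them as `w * norm_bm25 + (1 - w) * norm_distance`, flipping the distance term to `1 - norm` for an ascending metric. A small standalone sketch of that computation (the weight and sample scores are illustrative only):

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Min-max normalization: norm = (score - min) / (max - min),
// with every score mapped to 1.0 when max == min.
static std::vector<float> minMaxNormalize(const std::vector<float> & scores)
{
    std::vector<float> norm(scores.size(), 1.0f);
    if (scores.empty())
        return norm;
    auto [min_it, max_it] = std::minmax_element(scores.begin(), scores.end());
    float min_score = *min_it;
    float max_score = *max_it;
    if (min_score == max_score)
        return norm;
    for (size_t i = 0; i < scores.size(); ++i)
        norm[i] = (scores[i] - min_score) / (max_score - min_score);
    return norm;
}

int main()
{
    std::vector<float> bm25 = {7.1f, 5.4f, 2.3f};       // text scores, higher is better
    std::vector<float> distance = {0.08f, 0.15f, 0.4f}; // ascending metric, lower is better
    float w = 0.3f;                                     // weight of the text score

    auto norm_bm25 = minMaxNormalize(bm25);
    auto norm_dist = minMaxNormalize(distance);

    // Convex combination, flipping the distance term because the metric is ascending.
    for (size_t i = 0; i < bm25.size(); ++i)
    {
        float fused = w * norm_bm25[i] + (1 - w) * (1.0f - norm_dist[i]);
        std::cout << "row " << i << " fused score = " << fused << '\n';
    }
}
```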
- Float32 min_score, max_score, min_max_scale; - min_score = search_result_with_part_index[0].score; - max_score = search_result_with_part_index[result_size - 1].score; - - /// When min_score = max_score, norm_score = 1.0; - if (min_score == max_score) - { - LOG_DEBUG(log, "max_score and min_score are equal"); - for (size_t idx = 0; idx < result_size; idx++) - norm_score_vec.emplace_back(1.0); - - return; - } - else if (min_score > max_score) /// DESC - { - Float32 tmp_score = min_score; - min_score = max_score; - max_score = tmp_score; - } - - min_max_scale = max_score - min_score; - - /// min-max normalization score = (score - min_score) / (max_score - min_score) - for (size_t idx = 0; idx < result_size; idx++) - { - Float32 norm_score = (search_result_with_part_index[idx].score - min_score) / min_max_scale; - norm_score_vec.emplace_back(norm_score); - } -} - -void MergeTreeHybridSearchManager::RankFusion( - std::map, Float32> & part_index_labels_with_ranked_score, - const ScoreWithPartIndexAndLabels & vec_scan_result_with_part_index, - const ScoreWithPartIndexAndLabels & text_search_result_with_part_index, - int k) -{ - /// Ranked score = 1.0 / (k + rank(label_id)) - size_t idx = 0; - for (const auto & score_with_part_index_label : vec_scan_result_with_part_index) - { - Float32 rank_score = 1.0f / (k + idx + 1); - auto part_index_label = std::make_pair(score_with_part_index_label.part_index, score_with_part_index_label.label_id); - - /// For new (part_index, label_id) pair, map will insert. - /// part_index_labels_with_ranked_score map saved the fusion score for a (part_index, label_id) pair. - part_index_labels_with_ranked_score[part_index_label] += rank_score; - - idx++; - } - - idx = 0; - for (const auto & score_with_part_index_label : text_search_result_with_part_index) - { - Float32 rank_score = 1.0f / (k + idx + 1); - auto part_index_label = std::make_pair(score_with_part_index_label.part_index, score_with_part_index_label.label_id); - - /// For new (part_index, label_id) pair, map will insert. - /// part_index_labels_with_ranked_score map saved the fusion score for a (part_index, label_id) pair. 
- part_index_labels_with_ranked_score[part_index_label] += rank_score; - - idx++; - } -} - SearchResultAndRangesInDataParts MergeTreeHybridSearchManager::FilterPartsWithHybridResults( - const VectorAndTextResultInDataParts & parts_with_vector_text_result, + const RangesInDataParts & parts_with_ranges, const ScoreWithPartIndexAndLabels & hybrid_result_with_part_index, const Settings & settings, Poco::Logger * log) @@ -303,15 +185,16 @@ SearchResultAndRangesInDataParts MergeTreeHybridSearchManager::FilterPartsWithHy part_index_merged_map[part_index].emplace_back(score_with_part_index_label); } + size_t parts_with_ranges_size = parts_with_ranges.size(); SearchResultAndRangesInDataParts parts_with_ranges_hybrid_result; + parts_with_ranges_hybrid_result.resize(parts_with_ranges_size); /// Filter data part with part index in hybrid search and label ids for mark ranges - for (const auto & mix_results_in_part : parts_with_vector_text_result) + auto filter_part_with_results = [&](size_t part_index) { - const auto & part_with_ranges = mix_results_in_part.part_with_ranges; - size_t part_index = part_with_ranges.part_index_in_query; + const auto & part_with_ranges = parts_with_ranges[part_index]; - /// Check if part_index exists in map + /// Check if part_index for this part_with_ranges exists in map if (part_index_merged_map.contains(part_index)) { /// Found data part @@ -325,16 +208,52 @@ SearchResultAndRangesInDataParts MergeTreeHybridSearchManager::FilterPartsWithHy if (!mark_ranges_for_part.empty()) { - parts_with_ranges_hybrid_result.emplace_back( - part_with_ranges.data_part, - part_with_ranges.alter_conversions, - part_index, - std::move(mark_ranges_for_part), - tmp_hybrid_search_result); + RangesInDataPart ranges(part_with_ranges.data_part, + part_with_ranges.alter_conversions, + part_with_ranges.part_index_in_query, + std::move(mark_ranges_for_part)); + + SearchResultAndRangesInDataPart result_with_ranges(std::move(ranges), tmp_hybrid_search_result); + parts_with_ranges_hybrid_result[part_index] = std::move(result_with_ranges); } } + }; + + size_t num_threads = std::min(settings.max_threads, parts_with_ranges_size); + if (num_threads <= 1) + { + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) + filter_part_with_results(part_index); + } + else + { + /// Parallel executing filter parts_in_ranges with total top-k results + ThreadPool pool(CurrentMetrics::MergeTreeDataSelectHybridSearchThreads, CurrentMetrics::MergeTreeDataSelectHybridSearchThreadsActive, num_threads); + + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) + pool.scheduleOrThrowOnError([&, part_index]() + { + filter_part_with_results(part_index); + }); + + pool.wait(); + } + + /// Skip empty search result + size_t next_part = 0; + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) + { + auto & part_with_results = parts_with_ranges_hybrid_result[part_index]; + if (!part_with_results.search_result) + continue; + + if (next_part != part_index) + std::swap(parts_with_ranges_hybrid_result[next_part], part_with_results); + ++next_part; } + parts_with_ranges_hybrid_result.resize(next_part); + return parts_with_ranges_hybrid_result; } @@ -342,10 +261,9 @@ void MergeTreeHybridSearchManager::mergeResult( Columns & pre_result, size_t & read_rows, const ReadRanges & read_ranges, - const Search::DenseBitmapPtr filter, const ColumnUInt64 * part_offset) { - mergeSearchResultImpl(pre_result, read_rows, read_ranges, hybrid_search_result, filter, part_offset); 
+ mergeSearchResultImpl(pre_result, read_rows, read_ranges, hybrid_search_result, part_offset); } } diff --git a/src/VectorIndex/Storages/MergeTreeHybridSearchManager.h b/src/VectorIndex/Storages/MergeTreeHybridSearchManager.h index cbf286d16ed..f7d31be69b5 100644 --- a/src/VectorIndex/Storages/MergeTreeHybridSearchManager.h +++ b/src/VectorIndex/Storages/MergeTreeHybridSearchManager.h @@ -84,7 +84,6 @@ class MergeTreeHybridSearchManager : public MergeTreeBaseSearchManager Columns & pre_result, size_t & read_rows, const ReadRanges & read_ranges, - const Search::DenseBitmapPtr filter = nullptr, const ColumnUInt64 * part_offset = nullptr) override; bool preComputed() override @@ -115,7 +114,7 @@ class MergeTreeHybridSearchManager : public MergeTreeBaseSearchManager } #if USE_TANTIVY_SEARCH - void setBM25Stats(const Statistics & bm25_stats_in_table_) + void setBM25Stats(const TANTIVY::Statistics & bm25_stats_in_table_) { if (text_search_manager) text_search_manager->setBM25Stats(bm25_stats_in_table_); @@ -132,7 +131,7 @@ class MergeTreeHybridSearchManager : public MergeTreeBaseSearchManager /// Filter parts using total top-k hybrid search result /// For every part, select mark ranges to read, also save hybrid result static SearchResultAndRangesInDataParts FilterPartsWithHybridResults( - const VectorAndTextResultInDataParts & parts_with_vector_text_result, + const RangesInDataParts & parts_with_ranges, const ScoreWithPartIndexAndLabels & hybrid_result_with_part_index, const Settings & settings, Poco::Logger * log); @@ -147,27 +146,6 @@ class MergeTreeHybridSearchManager : public MergeTreeBaseSearchManager MergeTreeTextSearchManagerPtr text_search_manager = nullptr; Poco::Logger * log = &Poco::Logger::get("MergeTreeHybridSearchManager"); - - static void RelativeScoreFusion( - std::map, Float32> & part_index_labels_with_convex_score, - const ScoreWithPartIndexAndLabels & vec_scan_result_with_part_index, - const ScoreWithPartIndexAndLabels & text_search_result_with_part_index, - const float weight_of_text, - const int vector_scan_direction, - Poco::Logger * log); - - static void computeMinMaxNormScore( - const ScoreWithPartIndexAndLabels & search_result_with_part_index, - std::vector & norm_score_vec, - Poco::Logger * log); - - /// Compute reciprocal rank score for a (part_index, label id) pair - /// The map part_index_labels_with_ranked_score stores the sum of rank score for a (part_index, label id) pair - static void RankFusion( - std::map, Float32> & part_index_labels_with_ranked_score, - const ScoreWithPartIndexAndLabels & vec_scan_result_with_part_index, - const ScoreWithPartIndexAndLabels & text_search_result_with_part_index, - int k); }; using MergeTreeHybridSearchManagerPtr = std::shared_ptr; diff --git a/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.cpp b/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.cpp index 26b65ac7b76..89e20c2e81d 100644 --- a/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.cpp +++ b/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -48,7 +49,7 @@ namespace ErrorCodes extern const int MEMORY_LIMIT_EXCEEDED; } -/// Check if only select primary key column and vector search/text search/hybrid search functions. +/// Check if only select primary key column, _part_offset and vector search/text search/hybrid search functions. 
static bool isHybridSearchByPk(const std::vector & pk_col_names, const std::vector & read_col_names) { size_t pk_col_nums = pk_col_names.size(); @@ -63,7 +64,8 @@ static bool isHybridSearchByPk(const std::vector & pk_col_names, const s bool match = true; for (const auto & read_col_name : read_col_names) { - if ((read_col_name == pk_col_name) || isHybridSearchFunc(read_col_name) || isScoreColumnName(read_col_name)) + if ((read_col_name == pk_col_name) || read_col_name == "_part_offset" + || isHybridSearchFunc(read_col_name) || isScoreColumnName(read_col_name)) continue; else { @@ -458,46 +460,64 @@ IMergeTreeSelectAlgorithm::BlockAndProgress MergeTreeSelectWithHybridSearchProce if (read_result.num_rows == 0) return {Block(), read_result.num_rows, num_read_rows, num_read_bytes}; - /// Remove distance_func column from read_result.columns, it will be added by vector search. + /// Support multiple distance functions + /// Remove distance_func columns from read_result.columns, it will be added by vector search. Columns ordered_columns; - String vector_scan_col_name; + Names vector_scan_cols_names; + size_t cols_size_in_sample_block = sample_block.columns(); + size_t cols_size_except_search_cols = cols_size_in_sample_block; + if (base_search_manager) { - ordered_columns.reserve(sample_block.columns() - 1); - vector_scan_col_name = base_search_manager->getFuncColumnName(); + vector_scan_cols_names = base_search_manager->getSearchFuncColumnNames(); + cols_size_except_search_cols = cols_size_in_sample_block - vector_scan_cols_names.size(); + ordered_columns.reserve(cols_size_except_search_cols); + + /// Throw exception if vector_scan_cols_names is empty + if (vector_scan_cols_names.empty()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Failed to find any search result column name, this should not happen"); } else - ordered_columns.reserve(sample_block.columns()); + ordered_columns.reserve(cols_size_in_sample_block); + + /// sample block columns may be: + /// part without LWD: table columns + distance_func columns + non_const_virtual_columns + /// part with LWD: non_const_virtual_columns + table columns + distance_func columns + /// All distances are put at the end of ordered_columns, the order of distances is same as in vector_scan_descriptions. 
+ /// This vector is a map of the index of ordered_columns to the sample block + std::vector orig_pos_in_sample_block; + orig_pos_in_sample_block.resize(cols_size_in_sample_block); + size_t ordered_index = 0; - size_t which_cut = 0; - bool found_search_func_col = false; for (size_t ps = 0; ps < sample_block.columns(); ++ps) { auto & col_name = sample_block.getByPosition(ps).name; - /// TODO: not add distance column to header_without_virtual_columns - if (col_name == vector_scan_col_name) + /// Check if distance_func columns + bool is_search_func = false; + for (size_t i = 0; i < vector_scan_cols_names.size(); ++i) { - which_cut = ps; - found_search_func_col = true; - continue; + if (col_name == vector_scan_cols_names[i]) + { + orig_pos_in_sample_block[cols_size_except_search_cols+i] = ps; + is_search_func = true; + break; + } } + /// No need to put search func cols + if (is_search_func) + continue; + ordered_columns.emplace_back(std::move(read_result.columns[ps])); + orig_pos_in_sample_block[ordered_index] = ps; + ordered_index++; /// Copy _part_offset column if (col_name == "_part_offset") - { part_offset = typeid_cast(ordered_columns.back().get()); - } } - if (!found_search_func_col) - throw Exception( - ErrorCodes::LOGICAL_ERROR, - "Failed to find column name '{}' for search function in sample block during read", - vector_scan_col_name); - auto read_end_time = std::chrono::system_clock::now(); LOG_DEBUG(log, "Read time: {}", std::chrono::duration_cast(read_end_time - read_start_time).count()); @@ -509,11 +529,13 @@ IMergeTreeSelectAlgorithm::BlockAndProgress MergeTreeSelectWithHybridSearchProce base_search_manager->mergeResult( ordered_columns, read_result.num_rows, - read_ranges, nullptr, part_offset); + read_ranges, part_offset); } const size_t final_result_num_rows = read_result.num_rows; + LOG_DEBUG(log, "mergeResult() finished with result rows: {}", final_result_num_rows); + Block res_block; /// Add prewhere column name to avoid prewhere_column not found error @@ -532,34 +554,16 @@ IMergeTreeSelectAlgorithm::BlockAndProgress MergeTreeSelectWithHybridSearchProce res_block.insert(std::move(prewhere_col)); } + /// ordered_columns: non-search functions, search functions cols + /// Use the map orig_pos_in_sample_block to get column name and type from sample block for (size_t i = 0; i < ordered_columns.size(); ++i) { + size_t pos_in_sample = orig_pos_in_sample_block[i]; + ColumnWithTypeAndName ctn; ctn.column = ordered_columns[i]; - - if (i < ordered_columns.size() - 1) - { - size_t src_index = i >= which_cut ? i+1 : i; - ctn.type = sample_block.getByPosition(src_index).type; - ctn.name = sample_block.getByPosition(src_index).name; - } - else - { - ctn.name = vector_scan_col_name; - if (isBatchDistance(vector_scan_col_name)) - { - // the result of batch search, it's type is Tuple(UInt32, Float32) - DataTypes data_types; - data_types.emplace_back(std::make_shared()); - data_types.emplace_back(std::make_shared()); - ctn.type = std::make_shared(data_types); - } - else - { - // the result of single search, it's type is Float32 - ctn.type = std::make_shared(); - } - } + ctn.type = sample_block.getByPosition(pos_in_sample).type; + ctn.name = sample_block.getByPosition(pos_in_sample).name; res_block.insert(std::move(ctn)); } @@ -667,7 +671,8 @@ IMergeTreeSelectAlgorithm::BlockAndProgress MergeTreeSelectWithHybridSearchProce LOG_DEBUG(log, "Fetch from primary key cache size = {}", tmp_result_columns[0]->size()); - /// Get _part_offset if exists. 
+ /// Get _part_offset if exists + bool part_offset_exists_in_result = false; if (mutable_part_offset_col) { /// _part_offset column exists in original select columns @@ -675,8 +680,11 @@ IMergeTreeSelectAlgorithm::BlockAndProgress MergeTreeSelectWithHybridSearchProce { tmp_result_columns.emplace_back(std::move(mutable_part_offset_col)); part_offset = typeid_cast(tmp_result_columns.back().get()); + + /// Need to adjust order in results + part_offset_exists_in_result = true; } - else + else /// No need to put result columns, it's just used in mergeResult() for LWD part_offset = typeid_cast(mutable_part_offset_col.get()); } @@ -688,19 +696,42 @@ IMergeTreeSelectAlgorithm::BlockAndProgress MergeTreeSelectWithHybridSearchProce tmp_result_columns, /// _Inout_ result_row_num, /// _Out_ read_ranges, - nullptr, part_offset); + /// header_without_const_virtual_columns: pk columns + distance columns + non const virtual columns(_part_offset) + /// tmp_result_columns: pk columns + non const virtual columns(_part_offset) + distance columns Columns result_columns; + result_columns.resize(tmp_result_columns.size()); + + /// _part_offset column exists in original select columns + if (part_offset_exists_in_result) + { + /// Exchange order of non const virtual column and distance columns + size_t distances_size = base_search_manager->getSearchFuncColumnNames().size(); - if(!need_remove_part_offset){ + /// Throw exception if distances_size is empty + if (distances_size == 0) + throw Exception(ErrorCodes::LOGICAL_ERROR, "[PKCache] Failed to find any search result column name, this should not happen"); + + for (size_t i = 0; i < tmp_result_columns.size(); ++i) + { + size_t pos_in_result; + if (i < pk_col_size) + pos_in_result = i; + else if (i == pk_col_size) + pos_in_result = i + distances_size; + else + pos_in_result = i - 1; /// non const virtual column has ONE column: _part_offset + + result_columns[pos_in_result] = tmp_result_columns[i]; + } + } + else + { + /// No _part_offset column possibly added for LWD in tmp result columns result_columns = tmp_result_columns; - }else{ - result_columns.emplace_back(tmp_result_columns[0]); - result_columns.emplace_back(tmp_result_columns.back()); } - task->mark_ranges.clear(); if (result_row_num > 0) { @@ -775,7 +806,7 @@ void MergeTreeSelectWithHybridSearchProcessor::executeSearch(MarkRanges mark_ran executeSearch(base_search_manager, storage, storage_snapshot, data_part, alter_conversions, max_block_size_rows, preferred_block_size_bytes, preferred_max_column_in_block_size_bytes, mark_ranges, - prewhere_info_copy, reader_settings, use_uncompressed_cache, context, max_streams_for_prewhere); + prewhere_info_copy, reader_settings, context, max_streams_for_prewhere); } void MergeTreeSelectWithHybridSearchProcessor::executeSearch( @@ -790,7 +821,6 @@ void MergeTreeSelectWithHybridSearchProcessor::executeSearch( MarkRanges mark_ranges, const PrewhereInfoPtr & prewhere_info_copy, const MergeTreeReaderSettings & reader_settings_, - bool use_uncompressed_cache_, ContextPtr context_, size_t max_streams) { @@ -809,7 +839,7 @@ void MergeTreeSelectWithHybridSearchProcessor::executeSearch( /// 2 perform vector scan based on part_offsets auto filter = performPrefilter(mark_ranges, prewhere_info_copy, storage_, storage_snapshot_, data_part_, alter_conversions_, max_block_size, preferred_block_size_bytes_, - preferred_max_column_in_block_size_bytes_, reader_settings_, use_uncompressed_cache_, + preferred_max_column_in_block_size_bytes_, reader_settings_, context_, max_streams); 
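In the primary-key-cache branch above, `tmp_result_columns` arrive as pk columns, `_part_offset`, then distance columns, and are remapped to the header order of pk columns, distance columns, `_part_offset`. A standalone sketch of that index arithmetic, with hypothetical column names:

```cpp
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // Columns as produced on the primary key cache path:
    // pk columns, then _part_offset, then one distance column per distance function.
    std::vector<std::string> tmp_result = {"id", "_part_offset", "distance_0", "distance_1"};
    const size_t pk_col_size = 1;
    const size_t distances_size = 2;

    // Reorder to the header layout: pk columns, distance columns, then _part_offset.
    std::vector<std::string> result(tmp_result.size());
    for (size_t i = 0; i < tmp_result.size(); ++i)
    {
        size_t pos_in_result;
        if (i < pk_col_size)
            pos_in_result = i;                  // pk columns keep their positions
        else if (i == pk_col_size)
            pos_in_result = i + distances_size; // _part_offset jumps past the distance columns
        else
            pos_in_result = i - 1;              // distance columns shift left by one slot
        result[pos_in_result] = tmp_result[i];
    }

    for (const auto & name : result)
        std::cout << name << '\n'; // id, distance_0, distance_1, _part_offset
}
```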
ReadRanges read_ranges; @@ -820,6 +850,91 @@ void MergeTreeSelectWithHybridSearchProcessor::executeSearch( } } +namespace +{ + +struct PartRangesReadInfo +{ + size_t sum_marks = 0; + size_t total_rows = 0; + size_t index_granularity_bytes = 0; + size_t min_marks_for_concurrent_read = 0; + size_t min_rows_for_concurrent_read = 0; + + bool use_uncompressed_cache = false; + bool is_adaptive = false; + + PartRangesReadInfo( + const MergeTreeData::DataPartPtr & data_part, + const MarkRanges & mark_ranges, + const Settings & settings, + const MergeTreeSettings & data_settings) + { + /// Count marks to read for the part. + total_rows = data_part->index_granularity.getRowsCountInRanges(mark_ranges); + sum_marks = mark_ranges.getNumberOfMarks(); + + is_adaptive = data_part->index_granularity_info.mark_type.adaptive; + + if (is_adaptive) + index_granularity_bytes = data_settings.index_granularity_bytes; + + auto part_on_remote_disk = data_part->isStoredOnRemoteDisk(); + + size_t min_bytes_for_concurrent_read; + if (part_on_remote_disk) + { + min_rows_for_concurrent_read = settings.merge_tree_min_rows_for_concurrent_read_for_remote_filesystem; + min_bytes_for_concurrent_read = settings.merge_tree_min_bytes_for_concurrent_read_for_remote_filesystem; + } + else + { + min_rows_for_concurrent_read = settings.merge_tree_min_rows_for_concurrent_read; + min_bytes_for_concurrent_read = settings.merge_tree_min_bytes_for_concurrent_read; + } + + min_marks_for_concurrent_read = MergeTreeDataSelectExecutor::minMarksForConcurrentRead( + min_rows_for_concurrent_read, min_bytes_for_concurrent_read, + data_settings.index_granularity, index_granularity_bytes, sum_marks); + + /// Don't adjust this value based on sum_marks and max_marks_to_use_cache as in ReadFromMergeTree + use_uncompressed_cache = settings.use_uncompressed_cache; + } +}; + +template +VIBitmapPtr getFilterFromPipeline(size_t num_rows, Pipe & pipe) +{ + QueryPipelineBuilder builder; + builder.init(std::move(pipe)); + + QueryPipeline filter_pipeline = QueryPipelineBuilder::getPipeline(std::move(builder)); + + /// Use different pipeline executors + PullingExecutor filter_executor(filter_pipeline); + + Block block; + VIBitmapPtr filter = std::make_shared(num_rows); + { + OpenTelemetry::SpanHolder span_pipe("performPrefilter()::getFilterFromPipeline()"); + while (filter_executor.pull(block)) + { + if (block) + { + const PaddedPODArray & col_data = checkAndGetColumn(*block.getByName("_part_offset").column)->getData(); + for (size_t i = 0; i < block.rows(); ++i) + { + filter->set(col_data[i]); + } + } + } + } + + return filter; +} + +} + VIBitmapPtr MergeTreeSelectWithHybridSearchProcessor::performPrefilter( MarkRanges mark_ranges, const PrewhereInfoPtr & prewhere_info_copy, @@ -831,7 +946,6 @@ VIBitmapPtr MergeTreeSelectWithHybridSearchProcessor::performPrefilter( UInt64 preferred_block_size_bytes_, UInt64 preferred_max_column_in_block_size_bytes_, const MergeTreeReaderSettings & reader_settings_, - bool use_uncompressed_cache_, ContextPtr context_, size_t max_streams) { @@ -869,49 +983,76 @@ VIBitmapPtr MergeTreeSelectWithHybridSearchProcessor::performPrefilter( } } - /// Only one part - RangesInDataParts parts_with_ranges; - parts_with_ranges.emplace_back(data_part_, std::make_shared(), 0, mark_ranges); - - /// spreadMarkRangesAmongStreams() + /// Check if parallel reading mark ranges among streams is enabled + bool enable_parallel_reading = false; const auto & settings = context_->getSettingsRef(); const auto data_settings = storage_.getSettings(); + 
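The new `getFilterFromPipeline()` helper pulls blocks from the prewhere pipeline and sets a bit for every surviving `_part_offset`, with the pulling executor selected by a template parameter. A simplified, self-contained sketch of the same pattern; the bitmap type and the pulling callable below are stand-ins for `VIBitmap` and the real executors:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Simplified stand-in for the dense bitmap used as a prefilter.
using Bitmap = std::vector<bool>;

// Pull batches of _part_offset values until the source is exhausted and set the matching bits.
// In the real code the "source" is a query pipeline and the executor type is a template parameter;
// here it is just a callable returning the next batch, or nullopt when done.
template <typename PullBatch>
Bitmap getFilterFromSource(size_t num_rows, PullBatch && pull)
{
    Bitmap filter(num_rows, false);
    while (auto batch = pull())
        for (uint64_t part_offset : *batch)
            filter[part_offset] = true;
    return filter;
}

int main()
{
    // Two "blocks" of surviving row offsets, as a prewhere pipeline might produce.
    std::vector<std::vector<uint64_t>> blocks = {{0, 2, 3}, {7}};
    size_t next = 0;
    auto pull = [&]() -> std::optional<std::vector<uint64_t>>
    {
        if (next >= blocks.size())
            return std::nullopt;
        return blocks[next++];
    };

    Bitmap filter = getFilterFromSource(8, pull);
    for (size_t i = 0; i < filter.size(); ++i)
        std::cout << "row " << i << (filter[i] ? " passes" : " filtered") << '\n';
}
```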
 size_t num_rows = data_part_->rows_count;
-    size_t sum_marks = data_part_->getMarksCount();
-    size_t min_marks_for_concurrent_read = 0;
-    min_marks_for_concurrent_read = MergeTreeDataSelectExecutor::minMarksForConcurrentRead(
-        settings.merge_tree_min_rows_for_concurrent_read, settings.merge_tree_min_bytes_for_concurrent_read,
-        data_settings->index_granularity, data_settings->index_granularity_bytes, sum_marks);
+    PartRangesReadInfo info(data_part_, mark_ranges, settings, *data_settings);
     /// max streams for performing prewhere
     size_t num_streams = max_streams;
+
+    LOG_DEBUG(&Poco::Logger::get("performPreFilter"), "max_streams = {}, original min_marks_for_concurrent_read = {}, sum_marks = {}, total_rows = {}, min_rows_for_concurrent_read = {}",
+        max_streams, info.min_marks_for_concurrent_read, info.sum_marks, info.total_rows, info.min_rows_for_concurrent_read);
+
+    /// Enable parallel when num_streams > 1
     if (num_streams > 1)
     {
-        /// Reduce the number of num_streams if the data is small.
-        if (sum_marks < num_streams * min_marks_for_concurrent_read && parts_with_ranges.size() < num_streams)
-            num_streams = std::max((sum_marks + min_marks_for_concurrent_read - 1) / min_marks_for_concurrent_read, parts_with_ranges.size());
+        if (settings.parallel_reading_prefilter_option == 2)
+            enable_parallel_reading = true;
+        else if (settings.parallel_reading_prefilter_option == 1)
+        {
+            /// Adaptively enable parallel reading based on mark ranges and row count
+            /// Reduce the number of num_streams if the data is small.
+            if (info.sum_marks < num_streams * info.min_marks_for_concurrent_read)
+            {
+                const size_t prev_num_streams = num_streams;
+                num_streams = (info.sum_marks + info.min_marks_for_concurrent_read - 1) / info.min_marks_for_concurrent_read;
+                const size_t increase_num_streams_ratio = std::min(prev_num_streams / num_streams, info.min_marks_for_concurrent_read / 8);
+                if (increase_num_streams_ratio > 1)
+                {
+                    num_streams = num_streams * increase_num_streams_ratio;
+                    info.min_marks_for_concurrent_read = (info.sum_marks + num_streams - 1) / num_streams;
+                }
+            }
+            else if (info.total_rows < num_streams * info.min_rows_for_concurrent_read)
+            {
+                num_streams = (info.total_rows + info.min_rows_for_concurrent_read - 1) / info.min_rows_for_concurrent_read;
+                const size_t new_min_marks_for_concurrent_read = (info.sum_marks + num_streams - 1) / num_streams;
+                if (new_min_marks_for_concurrent_read > info.min_marks_for_concurrent_read)
+                    info.min_marks_for_concurrent_read = new_min_marks_for_concurrent_read;
+            }
+
+            if (num_streams > 1)
+                enable_parallel_reading = true;
+        }
     }
-    Pipe pipe;
+    LOG_DEBUG(&Poco::Logger::get("performPreFilter"), "num_streams = {}, min_marks_for_concurrent_read = {}", num_streams, info.min_marks_for_concurrent_read);
-    if (num_streams > 1)
+    /// Reading in multiple threads uses the async pulling executor
+    if (enable_parallel_reading)
     {
+        /// spreadMarkRangesAmongStreams()
+        /// Only one part
+        RangesInDataParts parts_with_ranges;
+        parts_with_ranges.emplace_back(data_part_, std::make_shared(), 0, mark_ranges);
         Pipes pipes;
-        if (max_block_size && !storage_.canUseAdaptiveGranularity())
+        if (max_block_size && !info.is_adaptive)
         {
-            size_t fixed_index_granularity = storage_.getSettings()->index_granularity;
-            min_marks_for_concurrent_read = (min_marks_for_concurrent_read * fixed_index_granularity + max_block_size - 1)
+            size_t fixed_index_granularity = data_settings->index_granularity;
+            info.min_marks_for_concurrent_read = (info.min_marks_for_concurrent_read *
fixed_index_granularity + max_block_size - 1) / max_block_size * max_block_size / fixed_index_granularity; } - auto total_rows_ = data_part_->index_granularity.getRowsCountInRanges(mark_ranges); - MergeTreeReadPoolPtr pool; pool = std::make_shared( num_streams, - sum_marks, - min_marks_for_concurrent_read, + info.sum_marks, + info.min_marks_for_concurrent_read, std::move(parts_with_ranges), storage_snapshot_, prewhere_info_copy, @@ -925,23 +1066,26 @@ VIBitmapPtr MergeTreeSelectWithHybridSearchProcessor::performPrefilter( for (size_t i = 0; i < num_streams; ++i) { auto algorithm = std::make_unique( - i, pool, min_marks_for_concurrent_read, max_block_size, + i, pool, info.min_marks_for_concurrent_read, max_block_size, settings.preferred_block_size_bytes, settings.preferred_max_column_in_block_size_bytes, - storage_, storage_snapshot_, use_uncompressed_cache_, + storage_, storage_snapshot_, info.use_uncompressed_cache, prewhere_info_copy, actions_settings, reader_settings_, system_columns); auto source = std::make_shared(std::move(algorithm)); if (i == 0) - source->addTotalRowsApprox(total_rows_); + source->addTotalRowsApprox(info.total_rows); pipes.emplace_back(std::move(source)); } - pipe = Pipe::unitePipes(std::move(pipes)); + Pipe pipe = Pipe::unitePipes(std::move(pipes)); + + return getFilterFromPipeline(num_rows, pipe); } else { + /// Read in a single thread auto algorithm = std::make_unique( storage_, storage_snapshot_, @@ -952,7 +1096,7 @@ VIBitmapPtr MergeTreeSelectWithHybridSearchProcessor::performPrefilter( preferred_max_column_in_block_size_bytes_, required_columns_prewhere, mark_ranges, - use_uncompressed_cache_, + info.use_uncompressed_cache, prewhere_info_copy, actions_settings, reader_settings_, @@ -961,44 +1105,19 @@ VIBitmapPtr MergeTreeSelectWithHybridSearchProcessor::performPrefilter( auto source = std::make_shared(std::move(algorithm)); - pipe = Pipe(std::move(source)); - } - - QueryPipelineBuilder builder; - builder.init(std::move(pipe)); - - QueryPipeline filter_pipeline = QueryPipelineBuilder::getPipeline(std::move(builder)); - PullingAsyncPipelineExecutor filter_executor(filter_pipeline); - - size_t num_rows = data_part_->rows_count; + Pipe pipe = Pipe(std::move(source)); - Block block; - VIBitmapPtr filter = std::make_shared(num_rows); - { - OpenTelemetry::SpanHolder span_pipe("MergeTreeSelectWithHybridSearchProcessor::performPrefilter()::StartPipe"); - while (filter_executor.pull(block)) - { - if (block) - { - const PaddedPODArray & col_data = checkAndGetColumn(*block.getByName("_part_offset").column)->getData(); - for (size_t i = 0; i < block.rows(); ++i) - { - filter->set(col_data[i]); - } - } - } + return getFilterFromPipeline(num_rows, pipe); } - - return filter; } VectorAndTextResultInDataParts MergeTreeSelectWithHybridSearchProcessor::selectPartsByVectorAndTextIndexes( - const RangesInDataParts & parts_with_range, + const RangesInDataParts & parts_with_ranges, const StorageMetadataPtr & metadata_snapshot, const SelectQueryInfo & query_info, - const bool support_two_stage_search, + const std::vector & vec_support_two_stage_searches, #if USE_TANTIVY_SEARCH - const Statistics & bm25_stats_in_table, + const TANTIVY::Statistics & bm25_stats_in_table, #endif const PrewhereInfoPtr & prewhere_info_, StorageSnapshotPtr storage_snapshot_, @@ -1013,7 +1132,8 @@ VectorAndTextResultInDataParts MergeTreeSelectWithHybridSearchProcessor::selectP if (!query_info.has_hybrid_search) return parts_with_mix_results; - parts_with_mix_results.resize(parts_with_range.size()); + 
size_t parts_with_ranges_size = parts_with_ranges.size(); + parts_with_mix_results.resize(parts_with_ranges_size); PrewhereInfoPtr prewhere_info_copy = nullptr; if (prewhere_info_) @@ -1028,45 +1148,47 @@ VectorAndTextResultInDataParts MergeTreeSelectWithHybridSearchProcessor::selectP /// Execute vector scan and text search in this part. auto process_part = [&](size_t part_index) { - auto & part_with_range = parts_with_range[part_index]; + auto & part_with_range = parts_with_ranges[part_index]; auto & data_part = part_with_range.data_part; auto & mark_ranges = part_with_range.ranges; - VectorAndTextResultInDataPart mix_results(part_with_range); + /// Save part_index in parts_with_ranges + VectorAndTextResultInDataPart mix_results(part_index, data_part); /// Handle three cases: vector scan, full-text seach and hybrid search if (query_info.hybrid_search_info) { auto hybrid_search_mgr = std::make_shared(metadata_snapshot, query_info.hybrid_search_info, - context, support_two_stage_search); + context, vec_support_two_stage_searches[0]); #if USE_TANTIVY_SEARCH hybrid_search_mgr->setBM25Stats(bm25_stats_in_table); #endif /// Get vector scan and text search executeSearch(hybrid_search_mgr, data, storage_snapshot_, data_part, part_with_range.alter_conversions, max_block_size, settings.preferred_block_size_bytes, settings.preferred_max_column_in_block_size_bytes, - mark_ranges, prewhere_info_copy, reader_settings_, settings.use_uncompressed_cache, + mark_ranges, prewhere_info_copy, reader_settings_, context, num_streams); if (hybrid_search_mgr) { - mix_results.vector_scan_result = hybrid_search_mgr->getVectorScanResult(); + mix_results.vector_scan_results.emplace_back(hybrid_search_mgr->getVectorScanResult()); mix_results.text_search_result = hybrid_search_mgr->getTextSearchResult(); } } else if (query_info.vector_scan_info) { auto vector_scan_mgr = std::make_shared(metadata_snapshot, query_info.vector_scan_info, - context, support_two_stage_search); + context, vec_support_two_stage_searches); /// Get vector scan executeSearch(vector_scan_mgr, data, storage_snapshot_, data_part, part_with_range.alter_conversions, max_block_size, settings.preferred_block_size_bytes, settings.preferred_max_column_in_block_size_bytes, - mark_ranges, prewhere_info_copy, reader_settings_, settings.use_uncompressed_cache, + mark_ranges, prewhere_info_copy, reader_settings_, context, num_streams); + /// Support multiple distance functions if (vector_scan_mgr && vector_scan_mgr->preComputed()) - mix_results.vector_scan_result = vector_scan_mgr->getSearchResult(); + mix_results.vector_scan_results = vector_scan_mgr->getVectorScanResults(); } else if (query_info.text_search_info) { @@ -1077,7 +1199,7 @@ VectorAndTextResultInDataParts MergeTreeSelectWithHybridSearchProcessor::selectP /// Get vector scan executeSearch(text_search_mgr, data, storage_snapshot_, data_part, part_with_range.alter_conversions, max_block_size, settings.preferred_block_size_bytes, settings.preferred_max_column_in_block_size_bytes, - mark_ranges, prewhere_info_copy, reader_settings_, settings.use_uncompressed_cache, + mark_ranges, prewhere_info_copy, reader_settings_, context, num_streams); if (text_search_mgr && text_search_mgr->preComputed()) @@ -1087,11 +1209,11 @@ VectorAndTextResultInDataParts MergeTreeSelectWithHybridSearchProcessor::selectP parts_with_mix_results[part_index] = std::move(mix_results); }; - size_t num_threads = std::min(num_streams, parts_with_range.size()); + size_t num_threads = std::min(num_streams, parts_with_ranges_size); if 
(num_threads <= 1) { - for (size_t part_index = 0; part_index < parts_with_range.size(); ++part_index) + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) process_part(part_index); } else @@ -1102,7 +1224,7 @@ VectorAndTextResultInDataParts MergeTreeSelectWithHybridSearchProcessor::selectP CurrentMetrics::MergeTreeDataSelectHybridSearchThreadsActive, num_threads); - for (size_t part_index = 0; part_index < parts_with_range.size(); ++part_index) + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) pool.scheduleOrThrowOnError([&, part_index, thread_group = CurrentThread::getGroup()] { SCOPE_EXIT_SAFE( diff --git a/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.h b/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.h index e085c65e656..132049f3e92 100644 --- a/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.h +++ b/src/VectorIndex/Storages/MergeTreeSelectWithHybridSearchProcessor.h @@ -54,12 +54,12 @@ class MergeTreeSelectWithHybridSearchProcessor final : public MergeTreeSelectAlg /// Execute vector scan, text or hybrid search on all parts /// For two stage search cases, execute first stage vector scan. static VectorAndTextResultInDataParts selectPartsByVectorAndTextIndexes( - const RangesInDataParts & parts_with_range, + const RangesInDataParts & parts_with_ranges, const StorageMetadataPtr & metadata_snapshot, const SelectQueryInfo & query_info, - const bool support_two_stage_search, + const std::vector & vec_support_two_stage_searches, #if USE_TANTIVY_SEARCH - const Statistics & bm25_stats_in_table, + const TANTIVY::Statistics & bm25_stats_in_table, #endif const PrewhereInfoPtr & prewhere_info, StorageSnapshotPtr storage_snapshot, @@ -110,7 +110,6 @@ class MergeTreeSelectWithHybridSearchProcessor final : public MergeTreeSelectAlg MarkRanges mark_ranges, const PrewhereInfoPtr & prewhere_info_copy, const MergeTreeReaderSettings & reader_settings_, - bool use_uncompressed_cache_, ContextPtr context_, size_t max_streams); @@ -126,7 +125,6 @@ class MergeTreeSelectWithHybridSearchProcessor final : public MergeTreeSelectAlg UInt64 preferred_block_size_bytes_, UInt64 preferred_max_column_in_block_size_bytes_, const MergeTreeReaderSettings & reader_settings_, - bool use_uncompressed_cache_, ContextPtr context_, size_t max_streams); diff --git a/src/VectorIndex/Storages/MergeTreeTextSearchManager.cpp b/src/VectorIndex/Storages/MergeTreeTextSearchManager.cpp index 7b13231b91f..d610ed87bbc 100644 --- a/src/VectorIndex/Storages/MergeTreeTextSearchManager.cpp +++ b/src/VectorIndex/Storages/MergeTreeTextSearchManager.cpp @@ -110,14 +110,27 @@ TextSearchResultPtr MergeTreeTextSearchManager::textSearch( bool find_index = false; for (const auto & index_desc : metadata->getSecondaryIndices()) { - if (index_desc.type == TANTIVY_INDEX_NAME && ((from_table_function && index_desc.name == text_search_info->index_name) - || (!from_table_function && index_desc.column_names.size() == 1 && index_desc.column_names[0] == search_column_name))) + if (index_desc.type == TANTIVY_INDEX_NAME) { - OpenTelemetry::SpanHolder span2("MergeTreeTextSearchManager::textSearch()::find_index::initialize index store"); + if (from_table_function) + { + /// Check index name when from table function + if (index_desc.name != text_search_info->index_name) + continue; + } + else /// from hybrid or text search + { + /// Check if index contains the text column name + auto & column_names = index_desc.column_names; + if 
(std::find(column_names.begin(), column_names.end(), search_column_name) == column_names.end()) + continue; + } + /// Found matched fts index based index name or text column name index_name = index_desc.name; suggestion_index_log = "ALTER TABLE " + db_table_name + " MATERIALIZE INDEX " + index_desc.name; + OpenTelemetry::SpanHolder span2("MergeTreeTextSearchManager::textSearch()::find_index::initialize index store"); /// Initialize TantivyIndexStore auto index_helper = MergeTreeIndexFactory::instance().get(index_desc); if (!index_helper->getDeserializedFormat(data_part->getDataPartStorage(), index_helper->getFileName())) @@ -161,8 +174,13 @@ TextSearchResultPtr MergeTreeTextSearchManager::textSearch( index_name, data_part->name, db_table_name, suggestion_index_log); } + /// Initialize vector for text columns names + std::vector text_column_names; + if (!search_column_name.empty()) + text_column_names.emplace_back(search_column_name); + /// Find index, load index and do text search - rust::cxxbridge1::Vec search_results; + rust::cxxbridge1::Vec search_results; if (filter) { OpenTelemetry::SpanHolder span3("MergeTreeTextSearchManager::textSearch()::data_part_generate_results_with_filter"); @@ -176,7 +194,7 @@ TextSearchResultPtr MergeTreeTextSearchManager::textSearch( filter_bitmap_vector.emplace_back(bitmap[i]); search_results = tantivy_store->bm25SearchWithFilter( - text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k, filter_bitmap_vector); + text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k, filter_bitmap_vector, text_column_names); } else if (data_part->hasLightweightDelete()) { @@ -227,19 +245,19 @@ TextSearchResultPtr MergeTreeTextSearchManager::textSearch( /// Get non empty delete bitmap (from store or data part) OR fail to get delete bitmap from part if (u8_delete_bitmap_vec.empty()) { - search_results = tantivy_store->bm25Search(text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k); + search_results = tantivy_store->bm25Search(text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k, text_column_names); } else { search_results = tantivy_store->bm25SearchWithFilter( - text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k, u8_delete_bitmap_vec); + text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k, u8_delete_bitmap_vec, text_column_names); } } else { OpenTelemetry::SpanHolder span3("MergeTreeTextSearchManager::textSearch()::data_part_generate_results_no_filter"); LOG_DEBUG(log, "Text search no filter"); - search_results = tantivy_store->bm25Search(text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k); + search_results = tantivy_store->bm25Search(text_search_info->query_text, enable_nlq, operator_or, bm25_stats_in_table, k, text_column_names); } for (size_t i = 0; i < search_results.size(); i++) @@ -264,10 +282,9 @@ void MergeTreeTextSearchManager::mergeResult( Columns & pre_result, size_t & read_rows, const ReadRanges & read_ranges, - const Search::DenseBitmapPtr filter, const ColumnUInt64 * part_offset) { - mergeSearchResultImpl(pre_result, read_rows, read_ranges, text_search_result, filter, part_offset); + mergeSearchResultImpl(pre_result, read_rows, read_ranges, text_search_result, part_offset); } } diff --git a/src/VectorIndex/Storages/MergeTreeTextSearchManager.h b/src/VectorIndex/Storages/MergeTreeTextSearchManager.h index a4bad8611df..1860de8f611 100644 --- 
a/src/VectorIndex/Storages/MergeTreeTextSearchManager.h +++ b/src/VectorIndex/Storages/MergeTreeTextSearchManager.h @@ -69,7 +69,6 @@ class MergeTreeTextSearchManager : public MergeTreeBaseSearchManager Columns & pre_result, size_t & read_rows, const ReadRanges & read_ranges, - const Search::DenseBitmapPtr filter = nullptr, const ColumnUInt64 * part_offset = nullptr) override; bool preComputed() override @@ -80,7 +79,7 @@ class MergeTreeTextSearchManager : public MergeTreeBaseSearchManager CommonSearchResultPtr getSearchResult() override { return text_search_result; } #if USE_TANTIVY_SEARCH - void setBM25Stats(const Statistics & bm25_stats_in_table_) + void setBM25Stats(const TANTIVY::Statistics & bm25_stats_in_table_) { bm25_stats_in_table = bm25_stats_in_table_; } @@ -91,7 +90,7 @@ class MergeTreeTextSearchManager : public MergeTreeBaseSearchManager TextSearchInfoPtr text_search_info; #if USE_TANTIVY_SEARCH - Statistics bm25_stats_in_table; /// total bm25 info from all parts in a table + TANTIVY::Statistics bm25_stats_in_table; /// total bm25 info from all parts in a table #endif /// lock text search result diff --git a/src/VectorIndex/Storages/MergeTreeVSManager.cpp b/src/VectorIndex/Storages/MergeTreeVSManager.cpp index 6c1123e4d0f..79e162e62af 100644 --- a/src/VectorIndex/Storages/MergeTreeVSManager.cpp +++ b/src/VectorIndex/Storages/MergeTreeVSManager.cpp @@ -21,13 +21,11 @@ #include #include - +#include #include #include #include - #include - #include #include #include @@ -36,11 +34,18 @@ #include #include #include +#include #include /// #define profile +namespace CurrentMetrics +{ + extern const Metric MergeTreeDataSelectHybridSearchThreads; + extern const Metric MergeTreeDataSelectHybridSearchThreadsActive; +} + namespace DB { @@ -257,7 +262,7 @@ void MergeTreeVSManager::executeSearchBeforeRead(const MergeTreeData::DataPartPt /// Skip to execute vector scan if already computed if (!preComputed() && vector_scan_info) - vector_scan_result = vectorScan(vector_scan_info->is_batch, data_part); + vector_scan_results = vectorScan(vector_scan_info->is_batch, data_part); } void MergeTreeVSManager::executeSearchWithFilter( @@ -267,196 +272,253 @@ void MergeTreeVSManager::executeSearchWithFilter( { /// Skip to execute vector scan if already computed if (!preComputed() && vector_scan_info) - vector_scan_result = vectorScan(vector_scan_info->is_batch, data_part, read_ranges, filter); + vector_scan_results = vectorScan(vector_scan_info->is_batch, data_part, read_ranges, filter); } -VectorScanResultPtr MergeTreeVSManager::vectorScan( +ManyVectorScanResults MergeTreeVSManager::vectorScan( bool is_batch, const MergeTreeData::DataPartPtr & data_part, const ReadRanges & read_ranges, const VIBitmapPtr filter) { OpenTelemetry::SpanHolder span("MergeTreeVSManager::vectorScan()"); - const VSDescriptions & descs = vector_scan_info->vector_scan_descs; - const VSDescription & desc = descs[0]; - const String search_column_name = desc.search_column_name; + /// Support multiple distance functions + ManyVectorScanResults multiple_vector_scan_results; - VectorScanResultPtr tmp_vector_scan_result = std::make_shared(); + size_t descs_size = vector_scan_info->vector_scan_descs.size(); + multiple_vector_scan_results.resize(descs_size); - tmp_vector_scan_result->result_columns.resize(3); - auto vector_id_column = DataTypeUInt32().createColumn(); - auto distance_column = DataTypeFloat32().createColumn(); - auto label_column = DataTypeUInt32().createColumn(); + /// Do vector scan on a single column + auto 
process_vector_scan_on_single_col = [&](size_t desc_index) + { + auto & vector_scan_desc = vector_scan_info->vector_scan_descs[desc_index]; + bool support_two_stage_search = vec_support_two_stage_searches[desc_index]; + + /// Do vector scan on a vector column + VectorScanResultPtr vec_scan_result = vectorScanOnSingleColumn(is_batch, data_part, vector_scan_desc, read_ranges, support_two_stage_search, filter); + + /// vectorScanOnSingleColumn() always return a non-null shared_ptr + if (vec_scan_result) + { + /// Add name of distance function column + vec_scan_result->name = vector_scan_desc.column_name; + + /// Add vector scan result on a vector column to results for multiple distances + multiple_vector_scan_results[desc_index] = std::move(vec_scan_result); + } + }; + + size_t num_threads = std::min(max_threads, descs_size); + if (num_threads <= 1) + { + for (size_t desc_index = 0; desc_index < descs_size; ++desc_index) + process_vector_scan_on_single_col(desc_index); + } + else + { + /// Parallel executing vector scan + ThreadPool pool(CurrentMetrics::MergeTreeDataSelectHybridSearchThreads, CurrentMetrics::MergeTreeDataSelectHybridSearchThreadsActive, num_threads); + + for (size_t desc_index = 0; desc_index < descs_size; ++desc_index) + pool.scheduleOrThrowOnError([&, desc_index]() + { + process_vector_scan_on_single_col(desc_index); + }); + + pool.wait(); + } + + return multiple_vector_scan_results; +} + +VectorScanResultPtr MergeTreeVSManager::vectorScanOnSingleColumn( + bool is_batch, + const MergeTreeData::DataPartPtr & data_part, + const VSDescription vector_scan_desc, + const ReadRanges & read_ranges, + const bool support_two_stage_search, + const Search::DenseBitmapPtr filter) +{ + OpenTelemetry::SpanHolder span("MergeTreeVSManager::vectorScanOnSingleColumn()"); + const String search_column_name = vector_scan_desc.search_column_name; + + VectorScanResultPtr tmp_vector_scan_result = std::make_shared(); VectorIndex::VectorDatasetVariantPtr vec_data; - switch (desc.vector_search_type) + switch (vector_scan_desc.vector_search_type) { case Search::DataType::FloatVector: - vec_data = generateVectorDataset(is_batch, desc); + vec_data = generateVectorDataset(is_batch, vector_scan_desc); break; case Search::DataType::BinaryVector: - vec_data = generateVectorDataset(is_batch, desc); + vec_data = generateVectorDataset(is_batch, vector_scan_desc); break; default: - throw Exception(ErrorCodes::LOGICAL_ERROR, "unsupported vector search type"); + throw Exception(ErrorCodes::LOGICAL_ERROR, "unsupported vector search type for column {}", search_column_name); } - UInt64 dim = desc.search_column_dim; - VIParameter search_params = VectorIndex::convertPocoJsonToMap(desc.vector_parameters); - LOG_DEBUG(log, "Search parameters: {}", search_params.toString()); - - int k = desc.topk > 0 ? desc.topk : VectorIndex::DEFAULT_TOPK; - search_params.erase("metric_type"); - - LOG_DEBUG(log, "Set k to {}, dim to {}", k, dim); + VIParameter search_params = VectorIndex::convertPocoJsonToMap(vector_scan_desc.vector_parameters); + if (!search_params.empty()) + { + LOG_DEBUG(log, "Search parameters: {} for vector column {}", search_params.toString(), search_column_name); + search_params.erase("metric_type"); + } - String metric_str; - bool enable_brute_force_for_part = bruteForceSearchEnabled(data_part); + UInt64 dim = vector_scan_desc.search_column_dim; + int k = vector_scan_desc.topk > 0 ? 
vector_scan_desc.topk : VectorIndex::DEFAULT_TOPK; + LOG_DEBUG(log, "Set k to {}, dim to {} for vector column {}", k, dim, search_column_name); VIWithColumnInPartPtr column_index; if (data_part->vector_index.getColumnIndexByColumnName(search_column_name).has_value()) column_index = data_part->vector_index.getColumnIndexByColumnName(search_column_name).value(); + /// Try to get or load vector index std::vector index_holders; - VIMetric metric; - if (column_index) - { index_holders = column_index->getIndexHolders(data_part->getState() != MergeTreeDataPartState::Outdated); - metric = column_index->getMetric(); - } - else - { - if (desc.vector_search_type == Search::DataType::FloatVector) - metric_str = data_part->storage.getSettings()->float_vector_search_metric_type; - else if (desc.vector_search_type == Search::DataType::BinaryVector) - metric_str = data_part->storage.getSettings()->binary_vector_search_metric_type; - metric = Search::getMetricType( - static_cast(metric_str), desc.vector_search_type); - } - if (!index_holders.empty()) + /// Fail to find vector index, try brute force search if enabled. + if (index_holders.empty()) { - int64_t query_vector_num = 0; - std::visit([&query_vector_num](auto &&vec_data_ptr) - { - query_vector_num = vec_data_ptr->getVectorNum(); - }, vec_data); - - /// find index - for (size_t i = 0; i < index_holders.size(); ++i) + if (bruteForceSearchEnabled(data_part)) { - auto & index_with_meta = index_holders[i]->value(); - OpenTelemetry::SpanHolder span3("MergeTreeVSManager::vectorScan()::find_index::search"); + VIMetric metric; - VIBitmapPtr real_filter = nullptr; - if (filter != nullptr) + if (column_index) + metric = column_index->getMetric(); + else { - real_filter = getRealBitmap(filter, index_with_meta); + String metric_str; + if (vector_scan_desc.vector_search_type == Search::DataType::FloatVector) + metric_str = data_part->storage.getSettings()->float_vector_search_metric_type; + else if (vector_scan_desc.vector_search_type == Search::DataType::BinaryVector) + metric_str = data_part->storage.getSettings()->binary_vector_search_metric_type; + metric = Search::getMetricType( + static_cast(metric_str), vector_scan_desc.vector_search_type); } - if (real_filter != nullptr && !real_filter->any()) - { - /// don't perform vector search if the segment is completely filtered out - continue; - } + VectorScanResultPtr res_without_index; + std::visit([&](auto &&vec_data_ptr) + { + res_without_index = vectorScanWithoutIndex(data_part, read_ranges, filter, vec_data_ptr, search_column_name, static_cast(dim), k, is_batch, metric); + }, vec_data); + return res_without_index; + } + else + { + /// No vector index available and brute force search disabled + tmp_vector_scan_result->computed = false; + return tmp_vector_scan_result; + } + } + + /// vector index search + tmp_vector_scan_result->result_columns.resize(is_batch ? 
3 : 2); + auto vector_id_column = DataTypeUInt32().createColumn(); + auto distance_column = DataTypeFloat32().createColumn(); + auto label_column = DataTypeUInt32().createColumn(); + + int64_t query_vector_num = 0; + std::visit([&query_vector_num](auto &&vec_data_ptr) + { + query_vector_num = vec_data_ptr->getVectorNum(); + }, vec_data); + + /// find index + for (size_t i = 0; i < index_holders.size(); ++i) + { + auto & index_with_meta = index_holders[i]->value(); + OpenTelemetry::SpanHolder span3("MergeTreeVSManager::vectorScan()::find_index::search"); + + VIBitmapPtr real_filter = nullptr; + if (filter != nullptr) + { + real_filter = getRealBitmap(filter, index_with_meta); + } + + if (real_filter != nullptr && !real_filter->any()) + { + /// don't perform vector search if the segment is completely filtered out + continue; + } - LOG_DEBUG(log, "Start search: vector num: {}", query_vector_num); + LOG_DEBUG(log, "Start search: vector num: {}", query_vector_num); - /// Although the vector index type support two stage search, the actual built index may fallback to flat. - bool first_stage_only = false; - if (support_two_stage_search && VIWithColumnInPart::supportTwoStageSearch(index_with_meta)) - first_stage_only = true; + /// Although the vector index type support two stage search, the actual built index may fallback to flat. + bool first_stage_only = false; + if (support_two_stage_search && VIWithColumnInPart::supportTwoStageSearch(index_with_meta)) + first_stage_only = true; - LOG_DEBUG(log, "first stage only = {}", first_stage_only); + LOG_DEBUG(log, "first stage only = {}", first_stage_only); - auto search_results = column_index->search(index_with_meta, vec_data, k, real_filter, search_params, first_stage_only); - auto per_id = search_results->getResultIndices(); - auto per_distance = search_results->getResultDistances(); + auto search_results = column_index->search(index_with_meta, vec_data, k, real_filter, search_params, first_stage_only); + auto per_id = search_results->getResultIndices(); + auto per_distance = search_results->getResultDistances(); - /// Update k value to num_reorder in two search stage. - if (first_stage_only) - k = search_results->getNumCandidates(); + /// Update k value to num_reorder in two search stage. 
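+        /// (getNumCandidates() is the enlarged candidate count, i.e. num_reorder; the second stage later
+        /// re-ranks this full candidate set with exact distances and trims it back to the requested top-k.)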
+ if (first_stage_only) + k = search_results->getNumCandidates(); - if (is_batch) + if (is_batch) + { + OpenTelemetry::SpanHolder span4("MergeTreeVSManager::vectorScan()::find_index::segment_batch_generate_results"); + for (int64_t label = 0; label < k * query_vector_num; ++label) { - OpenTelemetry::SpanHolder span4("MergeTreeVSManager::vectorScan()::find_index::segment_batch_generate_results"); - for (int64_t label = 0; label < k * query_vector_num; ++label) + UInt32 vector_id = static_cast(label / k); + if (per_id[label] > -1) { - UInt32 vector_id = static_cast(label / k); - if (per_id[label] > -1) - { - label_column->insert(per_id[label]); - vector_id_column->insert(vector_id); - distance_column->insert(per_distance[label]); - } + label_column->insert(per_id[label]); + vector_id_column->insert(vector_id); + distance_column->insert(per_distance[label]); } } - else + } + else + { + OpenTelemetry::SpanHolder span4("MergeTreeVSManager::vectorScan()::find_index::segment_generate_results"); + for (int64_t label = 0; label < k; ++label) { - OpenTelemetry::SpanHolder span4("MergeTreeVSManager::vectorScan()::find_index::segment_generate_results"); - for (int64_t label = 0; label < k; ++label) + if (per_id[label] > -1) { - if (per_id[label] > -1) - { - LOG_TRACE(log, "Label: {}, distance: {}", per_id[label], per_distance[label]); - label_column->insert(per_id[label]); - distance_column->insert(per_distance[label]); - } + LOG_TRACE(log, "Vector column: {}, label: {}, distance: {}", search_column_name, per_id[label], per_distance[label]); + label_column->insert(per_id[label]); + distance_column->insert(per_distance[label]); } } } - - if (is_batch) - { - OpenTelemetry::SpanHolder span3("MergeTreeVSManager::vectorScan()::find_index::data_part_batch_generate_results"); - tmp_vector_scan_result->result_columns[1] = std::move(vector_id_column); - tmp_vector_scan_result->result_columns[2] = std::move(distance_column); - } - else - { - OpenTelemetry::SpanHolder span3("MergeTreeVSManager::vectorScan()::find_index::data_part_generate_results"); - tmp_vector_scan_result->result_columns[1] = std::move(distance_column); - } - - tmp_vector_scan_result->computed = true; - tmp_vector_scan_result->result_columns[0] = std::move(label_column); - - return tmp_vector_scan_result; } - else if (enable_brute_force_for_part) + + if (is_batch) { - VectorScanResultPtr res_without_index; - std::visit([&](auto &&vec_data_ptr) - { - res_without_index = vectorScanWithoutIndex(data_part, read_ranges, filter, vec_data_ptr, search_column_name, static_cast(dim), k, is_batch, metric); - }, vec_data); - return res_without_index; + OpenTelemetry::SpanHolder span3("MergeTreeVSManager::vectorScan()::find_index::data_part_batch_generate_results"); + tmp_vector_scan_result->result_columns[1] = std::move(vector_id_column); + tmp_vector_scan_result->result_columns[2] = std::move(distance_column); } else { - /// No vector index available and brute force search disabled - tmp_vector_scan_result->computed = false; - return tmp_vector_scan_result; + OpenTelemetry::SpanHolder span3("MergeTreeVSManager::vectorScan()::find_index::data_part_generate_results"); + tmp_vector_scan_result->result_columns[1] = std::move(distance_column); } + + tmp_vector_scan_result->computed = true; + tmp_vector_scan_result->result_columns[0] = std::move(label_column); + + return tmp_vector_scan_result; } VectorScanResultPtr MergeTreeVSManager::executeSecondStageVectorScan( const MergeTreeData::DataPartPtr & data_part, - const VectorScanInfoPtr vector_scan_info_, + 
const VSDescription & vector_scan_desc, const VectorScanResultPtr & first_stage_vec_result) { OpenTelemetry::SpanHolder span("MergeTreeVSManager::executeSecondStageVectorScan()"); Poco::Logger * log_ = &Poco::Logger::get("executeSecondStageVectorScan"); - const VSDescriptions & descs = vector_scan_info_->vector_scan_descs; - const VSDescription & desc = descs[0]; - - if (desc.vector_search_type != Search::DataType::FloatVector) + if (vector_scan_desc.vector_search_type != Search::DataType::FloatVector) return first_stage_vec_result; - const String search_column_name = desc.search_column_name; + const String search_column_name = vector_scan_desc.search_column_name; bool brute_force = false; VIWithColumnInPartPtr column_index; std::vector index_holders; @@ -485,21 +547,21 @@ VectorScanResultPtr MergeTreeVSManager::executeSecondStageVectorScan( } /// Determine the top k value, no more than passed in result size. - int k = desc.topk > 0 ? desc.topk : VectorIndex::DEFAULT_TOPK; + int k = vector_scan_desc.topk > 0 ? vector_scan_desc.topk : VectorIndex::DEFAULT_TOPK; size_t num_reorder = first_label_col->size(); if (k > static_cast(num_reorder)) k = static_cast(num_reorder); - LOG_DEBUG(log_, "[executeSecondStageVectorScan] topk = {}, result size from first stage = {}", desc.topk, num_reorder); + LOG_DEBUG(log_, "[executeSecondStageVectorScan] topk = {}, result size from first stage = {}", vector_scan_desc.topk, num_reorder); VectorScanResultPtr tmp_vector_scan_result = std::make_shared(); tmp_vector_scan_result->result_columns.resize(2); auto distance_column = DataTypeFloat32().createColumn(); auto label_column = DataTypeUInt32().createColumn(); - /// Prepare paramters for computeTopDistanceSubset() if index supports two stage search - VectorIndex::VectorDatasetVariantPtr vec_data = generateVectorDataset(false, desc); + /// Prepare parameters for computeTopDistanceSubset() if index supports two stage search + VectorIndex::VectorDatasetVariantPtr vec_data = generateVectorDataset(false, vector_scan_desc); auto first_stage_result = Search::SearchResult::createTopKHolder(1, num_reorder); auto sr_indices = first_stage_result->getResultIndices(); auto sr_distances = first_stage_result->getResultDistances(); @@ -550,7 +612,7 @@ VectorScanResultPtr MergeTreeVSManager::executeSecondStageVectorScan( { if (per_id[label] > -1) { - LOG_TRACE(log_, "Label: {}, distance: {}", per_id[label], per_distance[label]); + LOG_TRACE(log_, "Vector column: {}, label: {}, distance: {}", search_column_name, per_id[label], per_distance[label]); label_column->insert(per_id[label]); distance_column->insert(per_distance[label]); } @@ -570,6 +632,7 @@ VectorScanResultPtr MergeTreeVSManager::executeSecondStageVectorScan( VectorAndTextResultInDataParts MergeTreeVSManager::splitFirstStageVSResult( const VectorAndTextResultInDataParts & parts_with_mix_results, const ScoreWithPartIndexAndLabels & first_stage_top_results, + const VSDescription & vector_scan_desc, Poco::Logger * log) { /// Merge candidate vector results from the same part index into a vector @@ -586,8 +649,7 @@ VectorAndTextResultInDataParts MergeTreeVSManager::splitFirstStageVSResult( /// Construct new vector scan result for data part existing in top candidates for (const auto & mix_results_in_part : parts_with_mix_results) { - const auto & part_with_ranges = mix_results_in_part.part_with_ranges; - size_t part_index = part_with_ranges.part_index_in_query; + size_t part_index = mix_results_in_part.part_index; /// Found data part in first stage top results if 
(part_index_merged_map.contains(part_index)) @@ -601,7 +663,7 @@ VectorAndTextResultInDataParts MergeTreeVSManager::splitFirstStageVSResult( auto score_column = DataTypeFloat32().createColumn(); auto label_column = DataTypeUInt32().createColumn(); - LOG_TEST(log, "First stage vector scan result for part {}:", part_with_ranges.data_part->name); + LOG_TEST(log, "First stage vector scan result for part {}:", mix_results_in_part.data_part->name); for (const auto & score_with_part_index_label : score_with_part_index_labels) { const auto & label_id = score_with_part_index_label.label_id; @@ -618,8 +680,11 @@ VectorAndTextResultInDataParts MergeTreeVSManager::splitFirstStageVSResult( tmp_vector_scan_result->result_columns[0] = std::move(label_column); tmp_vector_scan_result->result_columns[1] = std::move(score_column); - VectorAndTextResultInDataPart part_with_vector(part_with_ranges); - part_with_vector.vector_scan_result = tmp_vector_scan_result; + /// Add result column name + tmp_vector_scan_result->name = vector_scan_desc.column_name; + + VectorAndTextResultInDataPart part_with_vector(part_index, mix_results_in_part.data_part); + part_with_vector.vector_scan_results.emplace_back(tmp_vector_scan_result); parts_with_vector_result.emplace_back(std::move(part_with_vector)); } @@ -629,20 +694,148 @@ VectorAndTextResultInDataParts MergeTreeVSManager::splitFirstStageVSResult( return parts_with_vector_result; } +SearchResultAndRangesInDataParts MergeTreeVSManager::FilterPartsWithManyVSResults( + const RangesInDataParts & parts_with_ranges, + const std::unordered_map & vector_scan_results_with_part_index, + const Settings & settings, + Poco::Logger * log) +{ + OpenTelemetry::SpanHolder span("MergeTreeVSManager::FilterPartsWithManyVSResults()"); + + std::unordered_map part_index_merged_results_map; + std::unordered_map> part_index_merged_labels_map; + + for (const auto & [col_name, score_with_part_index_labels] : vector_scan_results_with_part_index) + { + /// For each vector scan, merge results from the same part index into a VectorScanResult + std::unordered_map part_index_single_vector_scan_result_map; + for (const auto & score_with_label : score_with_part_index_labels) + { + const auto & part_index = score_with_label.part_index; + const auto & label_id = score_with_label.label_id; + const auto & score = score_with_label.score; + + /// Save all labels from multiple vector scan results + part_index_merged_labels_map[part_index].emplace(label_id); + + /// Construct single vector scan result for each part + if (part_index_single_vector_scan_result_map.contains(part_index)) + { + auto & single_vector_scan_result = part_index_single_vector_scan_result_map[part_index]; + single_vector_scan_result->result_columns[0]->insert(label_id); + single_vector_scan_result->result_columns[1]->insert(score); + } + else + { + VectorScanResultPtr single_vector_scan_result = std::make_shared(); + single_vector_scan_result->result_columns.resize(2); + + auto label_column = DataTypeUInt32().createColumn(); + auto score_column = DataTypeFloat32().createColumn(); + + label_column->insert(label_id); + score_column->insert(score); + + single_vector_scan_result->computed = true; + single_vector_scan_result->name = col_name; + single_vector_scan_result->result_columns[0] = std::move(label_column); + single_vector_scan_result->result_columns[1] = std::move(score_column); + + /// Save in map + part_index_single_vector_scan_result_map[part_index] = single_vector_scan_result; + } + } + + /// For each data part, put single vector scan 
result into ManyVectorScanResults + for (const auto & [part_index, single_vector_scan_result] : part_index_single_vector_scan_result_map) + part_index_merged_results_map[part_index].emplace_back(single_vector_scan_result); + } + + size_t parts_with_ranges_size = parts_with_ranges.size(); + SearchResultAndRangesInDataParts parts_with_ranges_vector_scan_result; + parts_with_ranges_vector_scan_result.resize(parts_with_ranges_size); + + /// Filter data part with part index in vector scan results, filter mark ranges with label ids + auto filter_part_with_results = [&](size_t part_index) + { + const auto & part_with_ranges = parts_with_ranges[part_index]; + + /// Check if part_index exists in result map + if (part_index_merged_results_map.contains(part_index)) + { + /// Filter mark ranges with label ids in multiple vector scan results + MarkRanges mark_ranges_for_part = part_with_ranges.ranges; + auto & labels_set = part_index_merged_labels_map[part_index]; + filterMarkRangesByLabels(part_with_ranges.data_part, settings, labels_set, mark_ranges_for_part); + + if (!mark_ranges_for_part.empty()) + { + RangesInDataPart ranges(part_with_ranges.data_part, + part_with_ranges.alter_conversions, + part_with_ranges.part_index_in_query, + std::move(mark_ranges_for_part)); + SearchResultAndRangesInDataPart results_with_ranges(std::move(ranges), part_index_merged_results_map[part_index]); + parts_with_ranges_vector_scan_result[part_index] = std::move(results_with_ranges); + } + } + }; + + size_t num_threads = std::min(settings.max_threads, parts_with_ranges_size); + if (num_threads <= 1) + { + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) + filter_part_with_results(part_index); + } + else + { + /// Parallel executing filter parts_in_ranges with total top-k results + ThreadPool pool(CurrentMetrics::MergeTreeDataSelectHybridSearchThreads, CurrentMetrics::MergeTreeDataSelectHybridSearchThreadsActive, num_threads); + + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) + pool.scheduleOrThrowOnError([&, part_index]() + { + filter_part_with_results(part_index); + }); + + pool.wait(); + } + + /// Skip empty search result + size_t next_part = 0; + for (size_t part_index = 0; part_index < parts_with_ranges_size; ++part_index) + { + auto & part_with_results = parts_with_ranges_vector_scan_result[part_index]; + if (part_with_results.multiple_vector_scan_results.empty()) + continue; + + if (next_part != part_index) + std::swap(parts_with_ranges_vector_scan_result[next_part], part_with_results); + ++next_part; + } + + LOG_TRACE(log, "[FilterPartsWithManyVSResults] The size of parts with vector scan results: {}", next_part); + parts_with_ranges_vector_scan_result.resize(next_part); + + return parts_with_ranges_vector_scan_result; +} + void MergeTreeVSManager::mergeResult( Columns & pre_result, size_t & read_rows, const ReadRanges & read_ranges, - const VIBitmapPtr filter, const ColumnUInt64 * part_offset) { if (vector_scan_info && vector_scan_info->is_batch) { - mergeBatchVectorScanResult(pre_result, read_rows, read_ranges, vector_scan_result, filter, part_offset); + mergeBatchVectorScanResult(pre_result, read_rows, read_ranges, vector_scan_results[0], part_offset); + } + else if (search_func_cols_names.size() == 1) + { + mergeSearchResultImpl(pre_result, read_rows, read_ranges, vector_scan_results[0], part_offset); } else { - mergeSearchResultImpl(pre_result, read_rows, read_ranges, vector_scan_result, filter, part_offset); + 
mergeMultipleVectorScanResults(pre_result, read_rows, read_ranges, vector_scan_results, part_offset); } } @@ -651,7 +844,6 @@ void MergeTreeVSManager::mergeBatchVectorScanResult( size_t & read_rows, const ReadRanges & read_ranges, VectorScanResultPtr tmp_result, - const VIBitmapPtr filter, const ColumnUInt64 * part_offset) { OpenTelemetry::SpanHolder span("MergeTreeVSManager::mergeBatchVectorScanResult()"); @@ -669,119 +861,77 @@ void MergeTreeVSManager::mergeBatchVectorScanResult( final_result.emplace_back(col->cloneEmpty()); } - if (filter) + if (part_offset == nullptr) { - /// merge label and distance result into result columns - size_t current_column_pos = 0; - int range_index = 0; - size_t start_pos = read_ranges[range_index].start_row; - size_t offset = 0; - for (size_t i = 0; i < filter->get_size(); ++i) - { - if (offset >= read_ranges[range_index].row_num) - { - ++range_index; - start_pos = read_ranges[range_index].start_row; - offset = 0; - } + size_t start_pos = 0; + size_t end_pos = 0; + size_t prev_row_num = 0; - if (filter->unsafe_test(i)) + /// when no filter, the prev read result should be continuous, so we just need to scan all result rows and + /// keep results of which the row id is contained in label_column + for (auto & read_range : read_ranges) + { + start_pos = read_range.start_row; + end_pos = read_range.start_row + read_range.row_num; + for (size_t ind = 0; ind < label_column->size(); ++ind) { - for (size_t ind = 0; ind < label_column->size(); ++ind) + if (label_column->getUInt(ind) >= start_pos && label_column->getUInt(ind) < end_pos) { - /// start_pos + offset equals to the real row id - if (label_column->getUInt(ind) == start_pos + offset) + for (size_t i = 0; i < final_result.size(); ++i) { - /// for each result column - for (size_t j = 0; j < final_result.size(); ++j) - { - Field field; - pre_result[j]->get(current_column_pos, field); - final_result[j]->insert(field); - } - final_vector_id_column->insert(vector_id_column->getUInt(ind)); - final_distance_column->insert(distance_column->getFloat32(ind)); + Field field; + pre_result[i]->get(label_column->getUInt(ind) - start_pos + prev_row_num, field); + final_result[i]->insert(field); } + + final_vector_id_column->insert(vector_id_column->getUInt(ind)); + final_distance_column->insert(distance_column->getFloat32(ind)); } - ++current_column_pos; } - ++offset; + prev_row_num += read_range.row_num; } } - else + else // part_offset != nullptr { - if (part_offset == nullptr) + /// when no filter, the prev read result should be continuous, so we just need to scan all result rows and + /// keep results of which the row id is contained in label_column + for (auto & read_range : read_ranges) { - size_t start_pos = 0; - size_t end_pos = 0; - size_t prev_row_num = 0; - - /// when no filter, the prev read result should be continuous, so we just need to scan all result rows and - /// keep results of which the row id is contained in label_column - for (auto & read_range : read_ranges) + const size_t start_pos = read_range.start_row; + const size_t end_pos = read_range.start_row + read_range.row_num; + for (size_t ind = 0; ind < label_column->size(); ++ind) { - start_pos = read_range.start_row; - end_pos = read_range.start_row + read_range.row_num; - for (size_t ind = 0; ind < label_column->size(); ++ind) + const UInt64 physical_pos = label_column->getUInt(ind); + + if (physical_pos >= start_pos && physical_pos < end_pos) { - if (label_column->getUInt(ind) >= start_pos && label_column->getUInt(ind) < end_pos) + const 
ColumnUInt64::Container & offset_raw_value = part_offset->getData(); + const size_t part_offset_column_size = part_offset->size(); + size_t logic_pos = 0; + bool logic_pos_found = false; + for (size_t j = 0; j < part_offset_column_size; ++j) { - for (size_t i = 0; i < final_result.size(); ++i) + if (offset_raw_value[j] == physical_pos) { - Field field; - pre_result[i]->get(label_column->getUInt(ind) - start_pos + prev_row_num, field); - final_result[i]->insert(field); + logic_pos_found = true; + logic_pos = j; } - - final_vector_id_column->insert(vector_id_column->getUInt(ind)); - final_distance_column->insert(distance_column->getFloat32(ind)); } - } - prev_row_num += read_range.row_num; - } - } - else // part_offset != nullptr - { - /// when no filter, the prev read result should be continuous, so we just need to scan all result rows and - /// keep results of which the row id is contained in label_column - for (auto & read_range : read_ranges) - { - const size_t start_pos = read_range.start_row; - const size_t end_pos = read_range.start_row + read_range.row_num; - for (size_t ind = 0; ind < label_column->size(); ++ind) - { - const UInt64 physical_pos = label_column->getUInt(ind); - if (physical_pos >= start_pos && physical_pos < end_pos) + if (!logic_pos_found) { - const ColumnUInt64::Container & offset_raw_value = part_offset->getData(); - const size_t part_offset_column_size = part_offset->size(); - size_t logic_pos = 0; - bool logic_pos_found = false; - for (size_t j = 0; j < part_offset_column_size; ++j) - { - if (offset_raw_value[j] == physical_pos) - { - logic_pos_found = true; - logic_pos = j; - } - } - - if (!logic_pos_found) - { - continue; - } - - for (size_t i = 0; i < final_result.size(); ++i) - { - Field field; - pre_result[i]->get(logic_pos, field); - final_result[i]->insert(field); - } + continue; + } - final_vector_id_column->insert(vector_id_column->getUInt(ind)); - final_distance_column->insert(distance_column->getFloat32(ind)); + for (size_t i = 0; i < final_result.size(); ++i) + { + Field field; + pre_result[i]->get(logic_pos, field); + final_result[i]->insert(field); } + + final_vector_id_column->insert(vector_id_column->getUInt(ind)); + final_distance_column->insert(distance_column->getFloat32(ind)); } } } @@ -823,7 +973,6 @@ VectorScanResultPtr MergeTreeVSManager::vectorScanWithoutIndex( static VectorIndex::LimiterSharedContext vector_index_context(getNumberOfPhysicalCPUCores() * 2); VectorIndex::ScanThreadLimiter limiter(vector_index_context, log); - NamesAndTypesList cols; /// get search vector column info from part's metadata instead of table's /// Avoid issues when a new vector column has been added and existing old parts don't have it. 
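+    /// (Brute-force search reads the raw vector column over the selected mark ranges and computes distances
+    /// directly, so it also works for parts that have no built vector index.)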
@@ -1537,4 +1686,205 @@ bool MergeTreeVSManager::bruteForceSearchEnabled(const MergeTreeData::DataPartPt
     else
         return enable_brute_force_search;
 }
+
+void MergeTreeVSManager::mergeMultipleVectorScanResults(
+    Columns & pre_result,
+    size_t & read_rows,
+    const ReadRanges & read_ranges,
+    const ManyVectorScanResults multiple_vector_scan_results,
+    const ColumnUInt64 * part_offset)
+{
+    OpenTelemetry::SpanHolder span("MergeTreeVSManager::mergeMultipleVectorScanResults()");
+
+    /// Use search_func_cols_names in base manager
+    size_t distances_size = search_func_cols_names.size();
+
+    /// Initialize the sorted map by the result of multiple distances
+    if (map_labels_distances.empty())
+    {
+        /// Loop vector scan result for multiple distances
+        for (const auto & tmp_result : multiple_vector_scan_results)
+        {
+            if (!tmp_result)
+                continue;
+
+            const ColumnUInt32 * label_column = checkAndGetColumn<ColumnUInt32>(tmp_result->result_columns[0].get());
+            const ColumnFloat32 * distance_column = checkAndGetColumn<ColumnFloat32>(tmp_result->result_columns[1].get());
+            const String distance_col_name = tmp_result->name;
+
+            size_t dist_index = 0;
+            bool found = false;
+            for (size_t i = 0; i < distances_size; ++i)
+            {
+                if (search_func_cols_names[i] == distance_col_name)
+                {
+                    dist_index = i;
+                    found = true;
+                    break;
+                }
+            }
+
+            if (!found)
+            {
+                LOG_DEBUG(log, "Unknown distance column name '{}'", distance_col_name);
+                continue;
+            }
+
+            if (!label_column)
+            {
+                LOG_DEBUG(log, "Label column is null for distance column name '{}'", distance_col_name);
+                continue;
+            }
+
+            for (size_t ind = 0; ind < label_column->size(); ++ind)
+            {
+                /// Get label_id and distance value
+                auto label_id = label_column->getUInt(ind);
+                auto distance_score = distance_column->getFloat32(ind);
+
+                if (!map_labels_distances.contains(label_id))
+                {
+                    /// Default value for distance is NaN
+                    map_labels_distances[label_id].resize(distances_size, NAN);
+                }
+
+                map_labels_distances[label_id][dist_index] = distance_score;
+            }
+        }
+
+        if (map_labels_distances.size() > 0)
+            was_result_processed.assign(map_labels_distances.size(), false);
+
+        LOG_TEST(log, "[mergeMultipleVectorScanResults] results size: {}", map_labels_distances.size());
+        for (const auto & [label_id, dists_vec]: map_labels_distances)
+        {
+            String distances;
+            for (size_t i = 0; i < dists_vec.size(); ++i)
+                distances += (" " + toString(dists_vec[i]));
+
+            LOG_TEST(log, "label: {}, distances:{}", label_id, distances);
+        }
+    }
+
+    /// create new column vector to save final results
+    MutableColumns final_result;
+    LOG_DEBUG(log, "Create final result");
+    for (auto & col : pre_result)
+    {
+        final_result.emplace_back(col->cloneEmpty());
+    }
+
+    /// Create column vector to save multiple distances result
+    MutableColumns distances_result;
+    for (size_t i = 0; i < distances_size; ++i)
+        distances_result.emplace_back(ColumnFloat32::create());
+
+    if (part_offset == nullptr)
+    {
+        LOG_DEBUG(log, "No part offset");
+        size_t start_pos = 0;
+        size_t end_pos = 0;
+        size_t prev_row_num = 0;
+
+        for (auto & read_range : read_ranges)
+        {
+            start_pos = read_range.start_row;
+            end_pos = read_range.start_row + read_range.row_num;
+
+            /// loop map of labels and distances
+            size_t map_ind = 0;
+            for (const auto & [label_value, distance_score_vec] : map_labels_distances)
+            {
+                if (was_result_processed[map_ind])
+                {
+                    map_ind++;
+                    continue;
+                }
+
+                if (label_value >= start_pos && label_value < end_pos)
+                {
+                    const size_t index_of_arr = label_value - start_pos + prev_row_num;
+                    for (size_t i = 0; i < final_result.size(); ++i)
+                    {
+
Field field; + pre_result[i]->get(index_of_arr, field); + final_result[i]->insert(field); + } + + /// for each distance column + auto distances_score_vec = map_labels_distances[label_value]; + for (size_t col = 0; col < distances_size; ++col) + distances_result[col]->insert(distances_score_vec[col]); + + was_result_processed[map_ind] = true; + map_ind++; + } + } + prev_row_num += read_range.row_num; + } + } + else if (part_offset->size() > 0) + { + LOG_DEBUG(log, "Get part offset"); + + /// When lightweight delete applied, the rowid in the label column cannot be used as index of pre_result. + /// Match the rowid in the value of label col and the value of part_offset to find the correct index. + const ColumnUInt64::Container & offset_raw_value = part_offset->getData(); + size_t part_offset_size = part_offset->size(); + + /// loop map of labels and distances + size_t map_ind = 0; + for (const auto & [label_value, distance_score_vec] : map_labels_distances) + { + if (was_result_processed[map_ind]) + { + map_ind++; + continue; + } + + /// label_value (row id) = part_offset. + /// We can use binary search to quickly locate part_offset for current label. + size_t low = 0; + size_t high = part_offset_size; + while (low < high) + { + const size_t mid = low + (high - low) / 2; + + if (offset_raw_value[mid] == label_value) + { + /// Use the index of part_offset to locate other columns in pre_result and fill final_result. + for (size_t i = 0; i < final_result.size(); ++i) + { + Field field; + pre_result[i]->get(mid, field); + final_result[i]->insert(field); + } + + /// for each distance column + auto distances_score_vec = map_labels_distances[label_value]; + for (size_t col = 0; col < distances_size; ++col) + distances_result[col]->insert(distances_score_vec[col]); + + was_result_processed[map_ind] = true; + map_ind++; + + /// break from binary search loop + break; + } + else if (offset_raw_value[mid] < label_value) + low = mid + 1; + else + high = mid; + } + } + } + + for (size_t i = 0; i < pre_result.size(); ++i) + pre_result[i] = std::move(final_result[i]); + + read_rows = distances_result[0]->size(); + for (size_t i = 0; i < distances_size; ++i) + pre_result.emplace_back(std::move(distances_result[i])); +} + } diff --git a/src/VectorIndex/Storages/MergeTreeVSManager.h b/src/VectorIndex/Storages/MergeTreeVSManager.h index d6a17ccc580..71498dbe6dd 100644 --- a/src/VectorIndex/Storages/MergeTreeVSManager.h +++ b/src/VectorIndex/Storages/MergeTreeVSManager.h @@ -40,23 +40,36 @@ class MergeTreeVSManager : public MergeTreeBaseSearchManager { public: MergeTreeVSManager( - StorageMetadataPtr metadata_, VectorScanInfoPtr vector_scan_info_, ContextPtr context_, bool support_two_stage_search_ = false) - : MergeTreeBaseSearchManager{metadata_, context_, vector_scan_info_ ? 
vector_scan_info_->vector_scan_descs[0].column_name : ""} + StorageMetadataPtr metadata_, VectorScanInfoPtr vector_scan_info_, ContextPtr context_, std::vector vec_support_two_stage_searches_ = {}) + : MergeTreeBaseSearchManager{metadata_, context_} , vector_scan_info(vector_scan_info_) - , support_two_stage_search(support_two_stage_search_) + , vec_support_two_stage_searches(vec_support_two_stage_searches_) , enable_brute_force_search(context_->getSettingsRef().enable_brute_force_vector_search) + , max_threads(context_->getSettingsRef().max_threads) { + for (const auto & desc : vector_scan_info_->vector_scan_descs) + search_func_cols_names.emplace_back(desc.column_name); } - MergeTreeVSManager(VectorScanResultPtr vec_scan_result_, VectorScanInfoPtr vector_scan_info_) - : MergeTreeBaseSearchManager{nullptr, nullptr, vector_scan_info_ ? vector_scan_info_->vector_scan_descs[0].column_name : ""} + /// From hybrid search or batch_distance + MergeTreeVSManager( + StorageMetadataPtr metadata_, VectorScanInfoPtr vector_scan_info_, ContextPtr context_, bool support_two_stage_search_ = false) + : MergeTreeVSManager(metadata_, vector_scan_info_, context_, std::vector{support_two_stage_search_}) + {} + + /// Multiple vector scan functions + MergeTreeVSManager(const ManyVectorScanResults & vec_scan_results_, VectorScanInfoPtr vector_scan_info_) + : MergeTreeBaseSearchManager{nullptr, nullptr} , vector_scan_info(nullptr) - , vector_scan_result(vec_scan_result_) + , vector_scan_results(vec_scan_results_) { - if (vector_scan_result && vector_scan_result->computed) + if (preComputed()) { LOG_DEBUG(log, "Already have precomputed vector scan result, no need to execute search"); } + + for (const auto & desc : vector_scan_info_->vector_scan_descs) + search_func_cols_names.emplace_back(desc.column_name); } ~MergeTreeVSManager() override = default; @@ -72,40 +85,65 @@ class MergeTreeVSManager : public MergeTreeBaseSearchManager /// If part doesn't have vector index or real index type doesn't support, just use passed in values. 
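+    /// The second stage re-ranks the first-stage candidates with exact distances and trims the result back to the requested top-k.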
static VectorScanResultPtr executeSecondStageVectorScan( const MergeTreeData::DataPartPtr & data_part, - const VectorScanInfoPtr vector_scan_info_, + const VSDescription & vector_scan_desc, const VectorScanResultPtr & first_stage_vec_result); /// Split num_reorder candidates based on part index: part + vector scan results from first stage static VectorAndTextResultInDataParts splitFirstStageVSResult( const VectorAndTextResultInDataParts & parts_with_mix_results, const ScoreWithPartIndexAndLabels & first_stage_top_results, + const VSDescription & vector_scan_desc, + Poco::Logger * log); + + /// Filter parts using total top-k vector scan results from multiple distance functions + /// For every part, select mark ranges to read, and save multiple vector scan results + static SearchResultAndRangesInDataParts FilterPartsWithManyVSResults( + const RangesInDataParts & parts_with_ranges, + const std::unordered_map & vector_scan_results_with_part_index, + const Settings & settings, Poco::Logger * log); void mergeResult( Columns & pre_result, size_t & read_rows, const ReadRanges & read_ranges, - const Search::DenseBitmapPtr filter = nullptr, const ColumnUInt64 * part_offset = nullptr) override; bool preComputed() override { - return vector_scan_result && vector_scan_result->computed; + for (const auto & result : vector_scan_results) + { + if (result && result->computed) + return true; + } + + return false; + } + + /// Return first vector scan result if exists, used for hybrid search + CommonSearchResultPtr getSearchResult() override + { + if (vector_scan_results.size() > 0) + return vector_scan_results[0]; + else + return nullptr; } - CommonSearchResultPtr getSearchResult() override { return vector_scan_result; } + /// Return all vector scan results + ManyVectorScanResults getVectorScanResults() { return vector_scan_results; } private: VectorScanInfoPtr vector_scan_info; - bool support_two_stage_search; /// True if vector index in metadata support two stage search + std::vector vec_support_two_stage_searches; /// True if vector index in metadata support two stage search /// Whether brute force search is enabled based on query setting bool enable_brute_force_search; - /// lock vector scan result - std::mutex mutex; - VectorScanResultPtr vector_scan_result = nullptr; + /// Support multiple distance functions + ManyVectorScanResults vector_scan_results; + std::map> map_labels_distances; /// sorted map with label ids and multiple distances + size_t max_threads; Poco::Logger * log = &Poco::Logger::get("MergeTreeVSManager"); @@ -118,7 +156,7 @@ class MergeTreeVSManager : public MergeTreeBaseSearchManager template <> VectorIndex::VectorDatasetPtr generateVectorDataset(bool is_batch, const VSDescription & desc); - VectorScanResultPtr vectorScan( + ManyVectorScanResults vectorScan( bool is_batch, const MergeTreeData::DataPartPtr & data_part = nullptr, const ReadRanges & read_ranges = ReadRanges(), @@ -142,7 +180,24 @@ class MergeTreeVSManager : public MergeTreeBaseSearchManager size_t & read_rows, const ReadRanges & read_ranges = ReadRanges(), VectorScanResultPtr tmp_result = nullptr, - const Search::DenseBitmapPtr = nullptr, + const ColumnUInt64 * part_offset = nullptr); + + /// Support multiple distance functions + /// vector search on a vector column + VectorScanResultPtr vectorScanOnSingleColumn( + bool is_batch, + const MergeTreeData::DataPartPtr & data_part, + const VSDescription vector_scan_desc, + const ReadRanges & read_ranges, + const bool support_two_stage_search, + const Search::DenseBitmapPtr 
filter = nullptr); + + /// Merge multiple vector scan results with other columns + void mergeMultipleVectorScanResults( + Columns & pre_result, + size_t & read_rows, + const ReadRanges & read_ranges = ReadRanges(), + const ManyVectorScanResults multiple_vector_scan_results = {}, const ColumnUInt64 * part_offset = nullptr); template diff --git a/src/VectorIndex/Storages/VIBuilderUpdater.cpp b/src/VectorIndex/Storages/VIBuilderUpdater.cpp index d62f86930e1..1322495bb7e 100644 --- a/src/VectorIndex/Storages/VIBuilderUpdater.cpp +++ b/src/VectorIndex/Storages/VIBuilderUpdater.cpp @@ -131,7 +131,7 @@ void VIBuilderUpdater::removeDroppedVectorIndices(const StorageMetadataPtr & met VIParameter params = cache_item.second; /// Support multiple vector indices - /// TODO: Further check whether the paramters are the same. + /// TODO: Further check whether the parameters are the same. for (const auto & vec_index_desc : metadata_snapshot->getVectorIndices()) { if (cache_key.vector_index_name != vec_index_desc.name) diff --git a/src/VectorIndex/Storages/VSDescription.h b/src/VectorIndex/Storages/VSDescription.h index f6f5c59326d..8b09c1315cd 100644 --- a/src/VectorIndex/Storages/VSDescription.h +++ b/src/VectorIndex/Storages/VSDescription.h @@ -54,6 +54,7 @@ struct VSDescription }; using VSDescriptions = std::vector; +using MutableVSDescriptionsPtr = std::shared_ptr; struct VectorScanInfo { diff --git a/src/VectorIndex/Utils/CommonUtils.h b/src/VectorIndex/Utils/CommonUtils.h index 5c485ef7c1e..4fa4945925c 100644 --- a/src/VectorIndex/Utils/CommonUtils.h +++ b/src/VectorIndex/Utils/CommonUtils.h @@ -27,15 +27,20 @@ enum class DataType; namespace DB { +const String BATCH_DISTANCE_FUNCTION = "batch_distance"; +const String DISTANCE_FUNCTION = "distance"; +const String TEXT_SEARCH_FUNCTION = "textsearch"; +const String HYBRID_SEARCH_FUNCTION = "hybridsearch"; + const String SCORE_COLUMN_NAME = "bm25_score"; /// Different search types enum class HybridSearchFuncType { - VECTOR_SCAN = 0, + UNKNOWN_FUNC = 0, + VECTOR_SCAN, TEXT_SEARCH, - HYBRID_SEARCH, - UNKNOWN_FUNC + HYBRID_SEARCH }; class IDataType; @@ -44,13 +49,13 @@ using DataTypePtr = std::shared_ptr; inline bool isDistance(const String & func) { String func_to_low = Poco::toLower(func); - return func_to_low.find("distance") == 0; + return func_to_low.find(DISTANCE_FUNCTION) == 0; } inline bool isBatchDistance(const String & func) { String func_to_low = Poco::toLower(func); - return func_to_low.find("batch_distance") == 0; + return func_to_low.find(BATCH_DISTANCE_FUNCTION) == 0; } inline bool isVectorScanFunc(const String & func) @@ -61,13 +66,13 @@ inline bool isVectorScanFunc(const String & func) inline bool isTextSearch(const String & func) { String func_to_low = Poco::toLower(func); - return func_to_low.find("textsearch") == 0; + return func_to_low.find(TEXT_SEARCH_FUNCTION) == 0; } inline bool isHybridSearch(const String & func) { String func_to_low = Poco::toLower(func); - return func_to_low.find("hybridsearch") == 0; + return func_to_low.find(HYBRID_SEARCH_FUNCTION) == 0; } inline bool isHybridSearchFunc(const String & func) diff --git a/src/VectorIndex/Utils/HybridSearchUtils.cpp b/src/VectorIndex/Utils/HybridSearchUtils.cpp new file mode 100644 index 00000000000..a741ece4871 --- /dev/null +++ b/src/VectorIndex/Utils/HybridSearchUtils.cpp @@ -0,0 +1,316 @@ +/* + * Copyright (2024) MOQI SINGAPORE PTE. LTD. 
and/or its affiliates + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +/// Replace limit with num_candidates that is equal to limit * hybrid_search_top_k_multiple_base +inline void replaceLimitAST(ASTPtr & ast, UInt64 replaced_limit) +{ + const auto * select_query = ast->as(); + if (!select_query->limitLength()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "No limit in Distributed HybridSearch AST"); + + auto limit_ast = select_query->limitLength()->as(); + if (!limit_ast) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Bad limit in Distributed HybridSearch AST"); + + limit_ast->value = replaced_limit; +} + +/// Use Distributed Hybrid Search AST to create two separate ASTs for vector search and text search +void splitHybridSearchAST( + ASTPtr & hybrid_search_ast, + ASTPtr & vector_search_ast, + ASTPtr & text_search_ast, + int distance_order_by_direction, + UInt64 vector_limit, + UInt64 text_limit, + bool enable_nlq, + String text_operator) +{ + /// Replace the ASTFunction, ASTOrderByElement and LimitAST for Vector Search + { + vector_search_ast = hybrid_search_ast->clone(); + const auto * select_vector_query = vector_search_ast->as(); + + for (auto & child : select_vector_query->select()->children) + { + auto function = child->as(); + if (function && isHybridSearchFunc(function->name)) + { + child = makeASTFunction( + DISTANCE_FUNCTION, function->arguments->children[0]->clone(), function->arguments->children[2]->clone()); + } + + auto identifier = child->as(); + if (!identifier) + continue; + else if (identifier->name() == SCORE_TYPE_COLUMN.name) + { + /// Delete the SCORE_TYPE_COLUMN from the select list + select_vector_query->select()->children.erase( + std::remove(select_vector_query->select()->children.begin(), select_vector_query->select()->children.end(), child), + select_vector_query->select()->children.end()); + } + } + + /// Replace the HybridSearch function with DISTANCE_FUNCTION in the ORDER BY + if (!select_vector_query->orderBy()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "No ORDER BY in Distributed HybridSearch AST"); + for (auto & child : select_vector_query->orderBy()->children) + { + auto * order_by_element = child->as(); + if (!order_by_element || order_by_element->children.empty()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Bad ORDER BY expression AST"); + + auto function = order_by_element->children.at(0)->as(); + if (function && isHybridSearchFunc(function->name)) + { + order_by_element->children.at(0) = makeASTFunction( + DISTANCE_FUNCTION, function->arguments->children[0]->clone(), function->arguments->children[2]->clone()); + order_by_element->direction = distance_order_by_direction; + } + } + + replaceLimitAST(vector_search_ast, vector_limit); + } + + /// Replace the ASTFunction, ASTOrderByElement and LimitAST for Text Search + { + text_search_ast = hybrid_search_ast->clone(); + const auto * select_text_query = 
text_search_ast->as(); + + auto text_search_function_parameters = std::make_shared(); + text_search_function_parameters->children.push_back(std::make_shared("enable_nlq=" + std::to_string(enable_nlq))); + text_search_function_parameters->children.push_back(std::make_shared("operator=" + text_operator)); + + /// Replace the HybridSearch function with TEXT_SEARCH_FUNCTION in the select list + for (auto & child : select_text_query->select()->children) + { + auto function = child->as(); + if (function && isHybridSearchFunc(function->name)) + { + std::shared_ptr text_search_function = makeASTFunction( + TEXT_SEARCH_FUNCTION, function->arguments->children[1]->clone(), function->arguments->children[3]->clone()); + text_search_function->parameters = text_search_function_parameters->clone(); + text_search_function->children.push_back(text_search_function->parameters); + child = text_search_function; + } + + auto identifier = child->as(); + if (!identifier) + continue; + else if (identifier->name() == SCORE_TYPE_COLUMN.name) + { + /// Delete the SCORE_TYPE_COLUMN from the select list + select_text_query->select()->children.erase( + std::remove(select_text_query->select()->children.begin(), select_text_query->select()->children.end(), child), + select_text_query->select()->children.end()); + } + } + + /// Replace the HybridSearch function with TEXT_SEARCH_FUNCTION in the ORDER BY + if (!select_text_query->orderBy()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "No ORDER BY in Distributed HybridSearch AST"); + for (auto & child : select_text_query->orderBy()->children) + { + auto * order_by_element = child->as(); + if (!order_by_element || order_by_element->children.empty()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Bad ORDER BY expression AST"); + + auto function = order_by_element->children.at(0)->as(); + if (function && isHybridSearchFunc(function->name)) + { + std::shared_ptr text_search_function = makeASTFunction( + TEXT_SEARCH_FUNCTION, function->arguments->children[1]->clone(), function->arguments->children[3]->clone()); + text_search_function->parameters = text_search_function_parameters->clone(); + text_search_function->children.push_back(text_search_function->parameters); + + order_by_element->children.at(0) = text_search_function; + order_by_element->direction = -1; + } + } + + replaceLimitAST(text_search_ast, text_limit); + } +} + +/// RRF_score = 1.0 / (fusion_k + rank(fusion_id_bm25)) + 1.0 / (fusion_k + rank(fusion_id_distance)) +void RankFusion( + std::map, Float32> & fusion_id_with_score, + const ScoreWithPartIndexAndLabels & vec_scan_result_dataset, + const ScoreWithPartIndexAndLabels & text_search_result_dataset, + const UInt64 fusion_k, + Poco::Logger * log) +{ + size_t idx = 0; + for (const auto & vector_score_with_label : vec_scan_result_dataset) + { + auto fusion_id + = std::make_tuple(vector_score_with_label.shard_num, vector_score_with_label.part_index, vector_score_with_label.label_id); + + /// For new (shard_num, part_index, label_id) tuple, map will insert. + /// fusion_id_with_score map saved the fusion score for a (shard_num, part_index, label_id) tuple. + /// AS for single-shard hybrid search, shard_num is always 0. 
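+        /// e.g. with fusion_k = 60 the top-ranked vector hit contributes 1.0 / 61, the second 1.0 / 62, and a document
+        /// that also appears in the text ranking gets the matching 1.0 / (60 + rank) added to the same map entry below.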
+ fusion_id_with_score[fusion_id] += 1.0f / (fusion_k + idx + 1); + idx++; + + LOG_TRACE( + log, + "fusion_id: [{}, {}, {}], ranked_score: {}", + vector_score_with_label.shard_num, + vector_score_with_label.part_index, + vector_score_with_label.label_id, + fusion_id_with_score[fusion_id]); + } + + idx = 0; + for (const auto & text_score_with_label : text_search_result_dataset) + { + auto fusion_id = std::make_tuple(text_score_with_label.shard_num, text_score_with_label.part_index, text_score_with_label.label_id); + + /// Insert or update fusion score for fusion_id + fusion_id_with_score[fusion_id] += 1.0f / (fusion_k + idx + 1); + idx++; + + LOG_TRACE( + log, + "fusion_id: [{}, {}, {}], ranked_score: {}", + text_score_with_label.shard_num, + text_score_with_label.part_index, + text_score_with_label.label_id, + fusion_id_with_score[fusion_id]); + } +} + +/// RSF_score = normalized_bm25_score * fusion_weight + normalized_distance_score * (1 - fusion_weight) +void RelativeScoreFusion( + std::map, Float32> & fusion_id_with_score, + const ScoreWithPartIndexAndLabels & vec_scan_result_dataset, + const ScoreWithPartIndexAndLabels & text_search_result_dataset, + const Float32 fusion_weight, + const Int8 vector_scan_direction, + Poco::Logger * log) +{ + /// Normalize text search score + std::vector norm_score; + computeNormalizedScore(text_search_result_dataset, norm_score, log); + + LOG_TRACE(log, "Text Search Scores:"); + for (size_t idx = 0; idx < text_search_result_dataset.size(); idx++) + { + auto fusion_id = std::make_tuple( + text_search_result_dataset[idx].shard_num, + text_search_result_dataset[idx].part_index, + text_search_result_dataset[idx].label_id); + + LOG_TRACE( + log, + "fusion_id=[{}, {}, {}], origin_score={}, norm_score={}", + text_search_result_dataset[idx].shard_num, + text_search_result_dataset[idx].part_index, + text_search_result_dataset[idx].label_id, + text_search_result_dataset[idx].score, + norm_score[idx]); + + fusion_id_with_score[fusion_id] = norm_score[idx] * fusion_weight; + } + + /// Normalize vector search score + norm_score.clear(); + computeNormalizedScore(vec_scan_result_dataset, norm_score, log); + + LOG_TRACE(log, "Vector Search Scores:"); + for (size_t idx = 0; idx < vec_scan_result_dataset.size(); idx++) + { + auto fusion_id = std::make_tuple( + vec_scan_result_dataset[idx].shard_num, vec_scan_result_dataset[idx].part_index, vec_scan_result_dataset[idx].label_id); + + LOG_TRACE( + log, + "fusion_id=[{}, {}, {}], origin_score={}, norm_score={}", + vec_scan_result_dataset[idx].shard_num, + vec_scan_result_dataset[idx].part_index, + vec_scan_result_dataset[idx].label_id, + vec_scan_result_dataset[idx].score, + norm_score[idx]); + + Float32 fusion_distance_score = 0; + + /// 1 - ascending, -1 - descending + if (vector_scan_direction == -1) + fusion_distance_score = norm_score[idx] * (1 - fusion_weight); + else + fusion_distance_score = (1 - norm_score[idx]) * (1 - fusion_weight); + + /// Insert or update fusion score for fusion_id + fusion_id_with_score[fusion_id] += fusion_distance_score; + } +} + +void computeNormalizedScore( + const ScoreWithPartIndexAndLabels & search_result_dataset, std::vector & norm_score, Poco::Logger * log) +{ + const auto result_size = search_result_dataset.size(); + if (result_size == 0) + { + LOG_DEBUG(log, "search result is empty"); + return; + } + + norm_score.reserve(result_size); + + /// The search_result_dataset is already ordered + /// As for bm25 score and metric_type=IP, the scores are in ascending order; otherwise, it is 
in descending order. + Float32 min_score, max_score, min_max_scale; + + /// Here assume the scores in score column are ordered in descending order. + min_score = search_result_dataset[result_size - 1].score; + max_score = search_result_dataset[0].score; + + /// When min_score == max_score, all scores are the same, so the normalized score is 1.0 + if (min_score == max_score) + { + for (size_t idx = 0; idx < result_size; idx++) + norm_score.emplace_back(1.0f); + return; + } + else if (min_score > max_score) /// ASC + { + std::swap(min_score, max_score); + } + + min_max_scale = max_score - min_score; + for (size_t idx = 0; idx < result_size; idx++) + { + Float32 normalizing_score = (search_result_dataset[idx].score - min_score) / min_max_scale; + norm_score.emplace_back(normalizing_score); + } +} + +} diff --git a/src/VectorIndex/Utils/HybridSearchUtils.h b/src/VectorIndex/Utils/HybridSearchUtils.h new file mode 100644 index 00000000000..633096a4c19 --- /dev/null +++ b/src/VectorIndex/Utils/HybridSearchUtils.h @@ -0,0 +1,61 @@ +/* + * Copyright (2024) MOQI SINGAPORE PTE. LTD. and/or its affiliates + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +using ASTPtr = std::shared_ptr; + +const String HYBRID_SEARCH_SCORE_COLUMN_NAME = "HybridSearch_func"; + +/// Distributed Hybrid Search additional columns +const NameAndTypePair SCORE_TYPE_COLUMN{"_distributed_hybrid_search_score_type", std::make_shared()}; + +void splitHybridSearchAST( + ASTPtr & hybrid_search_ast, + ASTPtr & vector_search_ast, + ASTPtr & text_search_ast, + int distance_order_by_direction, + UInt64 vector_limit, + UInt64 text_limit, + bool enable_nlq, + String text_operator); + +void RankFusion( + std::map, Float32> & fusion_id_with_score, + const ScoreWithPartIndexAndLabels & vec_scan_result_dataset, + const ScoreWithPartIndexAndLabels & text_search_result_dataset, + const UInt64 fusion_k, + Poco::Logger * log); + +void RelativeScoreFusion( + std::map, Float32> & fusion_id_with_score, + const ScoreWithPartIndexAndLabels & vec_scan_result_dataset, + const ScoreWithPartIndexAndLabels & text_search_result_dataset, + const Float32 fusion_weight, + const Int8 vector_scan_direction, + Poco::Logger * log); + +void computeNormalizedScore( + const ScoreWithPartIndexAndLabels & search_result_dataset, std::vector & norm_score, Poco::Logger * log); +} diff --git a/src/VectorIndex/Utils/VSUtils.cpp b/src/VectorIndex/Utils/VSUtils.cpp index 7b3739c3cf9..799fac0e904 100644 --- a/src/VectorIndex/Utils/VSUtils.cpp +++ b/src/VectorIndex/Utils/VSUtils.cpp @@ -39,7 +39,7 @@ void filterMarkRangesByVectorScanResult(MergeTreeData::DataPartPtr part, MergeTr void filterMarkRangesBySearchResult(MergeTreeData::DataPartPtr part, const Settings & settings, CommonSearchResultPtr common_search_result, MarkRanges & mark_ranges) { - OpenTelemetry::SpanHolder span("filterMarkRangesByVectorScanResult()"); + OpenTelemetry::SpanHolder 
span("filterMarkRangesBySearchResult()"); MarkRanges res; if (!common_search_result || !common_search_result->computed) diff --git a/tests/integration/test_mqvs_distributed_hybrid_search/__init__.py b/tests/integration/test_mqvs_distributed_hybrid_search/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integration/test_mqvs_distributed_hybrid_search/configs/remote_servers.xml b/tests/integration/test_mqvs_distributed_hybrid_search/configs/remote_servers.xml new file mode 100644 index 00000000000..791af83a2d6 --- /dev/null +++ b/tests/integration/test_mqvs_distributed_hybrid_search/configs/remote_servers.xml @@ -0,0 +1,28 @@ + + + + + true + + node1 + 9000 + + + node2 + 9000 + + + + true + + node3 + 9000 + + + node4 + 9000 + + + + + diff --git a/tests/integration/test_mqvs_distributed_hybrid_search/test.py b/tests/integration/test_mqvs_distributed_hybrid_search/test.py new file mode 100644 index 00000000000..5492c2dd2ef --- /dev/null +++ b/tests/integration/test_mqvs_distributed_hybrid_search/test.py @@ -0,0 +1,121 @@ +import pytest +import time +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__) + +# shard1 +node_s1_r1 = cluster.add_instance("node1", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) +node_s1_r2 = cluster.add_instance("node2", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) + +# shard2 +node_s2_r1 = cluster.add_instance("node3", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) +node_s2_r2 = cluster.add_instance("node4", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) + +data = [ + (1, [1.0, 1.0, 1.0, 1.0, 1.0], 'The mayor announced a new initiative to revitalize the downtown area. This project will include the construction of new parks and the renovation of historic buildings.'), + (2, [2.0, 2.0, 2.0, 2.0, 2.0], 'Local schools are introducing a new curriculum focused on science and technology. The goal is to better prepare students for careers in STEM fields.'), + (3, [3.0, 3.0, 3.0, 3.0, 3.0], 'A new community center is opening next month, offering a variety of programs for residents of all ages. Activities include fitness classes, art workshops, and social events.'), + (4, [4.0, 4.0, 4.0, 4.0, 4.0], 'The city council has approved a plan to improve public transportation. This includes expanding bus routes and adding more frequent services during peak hours.'), + (5, [5.0, 5.0, 5.0, 5.0, 5.0], 'A new library is being built in the west side of the city. The library will feature modern facilities, including a digital media lab and community meeting rooms.'), + (6, [6.0, 6.0, 6.0, 6.0, 6.0], 'The local hospital has received funding to upgrade its emergency department. The improvements will enhance patient care and reduce wait times.'), + (7, [7.0, 7.0, 7.0, 7.0, 7.0], 'A new startup accelerator has been launched to support tech entrepreneurs. The program offers mentoring, networking opportunities, and access to investors.'), + (8, [8.0, 8.0, 8.0, 8.0, 8.0], 'The city is hosting a series of public workshops on climate change. The sessions aim to educate residents on how to reduce their carbon footprint.'), + (9, [9.0, 9.0, 9.0, 9.0, 9.0], 'A popular local restaurant is expanding with a new location in the downtown area. 
The restaurant is known for its farm-to-table cuisine and sustainable practices.'), + (10, [10.0, 10.0, 10.0, 10.0, 10.0], 'The annual arts festival is set to begin next week, featuring performances, exhibitions, and workshops by local and international artists.'), + (11, [11.0, 11.0, 11.0, 11.0, 11.0], 'The city is implementing new measures to improve air quality, including stricter emissions standards for industrial facilities and incentives for electric vehicles.'), + (12, [12.0, 12.0, 12.0, 12.0, 12.0], 'A local nonprofit is organizing a food drive to support families in need. Donations of non-perishable food items can be dropped off at designated locations.'), + (13, [13.0, 13.0, 13.0, 13.0, 13.0], 'The community garden project is expanding, with new plots available for residents to grow their own vegetables and herbs. The garden promotes healthy eating and sustainability.'), + (14, [14.0, 14.0, 14.0, 14.0, 14.0], 'The police department is introducing body cameras for officers to increase transparency and accountability. The initiative is part of broader efforts to build trust with the community.'), + (15, [15.0, 15.0, 15.0, 15.0, 15.0], 'A new public swimming pool is opening this summer, offering swimming lessons and recreational activities for all ages. The pool is part of a larger effort to promote health and wellness.'), + (16, [16.0, 16.0, 16.0, 16.0, 16.0], 'The city is launching a campaign to promote recycling and reduce waste. Residents are encouraged to participate by recycling household items and composting organic waste.'), + (17, [17.0, 17.0, 17.0, 17.0, 17.0], 'A local theater group is performing a series of classic plays at the community center. The performances aim to make theater accessible to a wider audience.'), + (18, [18.0, 18.0, 18.0, 18.0, 18.0], 'The city is investing in renewable energy projects, including the installation of solar panels on public buildings and the development of wind farms.'), + (19, [19.0, 19.0, 19.0, 19.0, 19.0], 'A new sports complex is being built to provide facilities for basketball, soccer, and other sports. The complex will also include a fitness center and walking trails.'), + (20, [20.0, 20.0, 20.0, 20.0, 20.0], 'The city is hosting a series of workshops on financial literacy, aimed at helping residents manage their money and plan for the future. Topics include budgeting, saving, and investing.'), + (21, [21.0, 21.0, 21.0, 21.0, 21.0], 'A new art exhibit is opening at the city museum, featuring works by contemporary artists from around the world. The exhibit aims to foster a greater appreciation for modern art.'), + (22, [22.0, 22.0, 22.0, 22.0, 22.0], 'The local animal shelter is holding an adoption event this weekend. Dogs, cats, and other pets are available for adoption, and volunteers will be on hand to provide information.'), + (23, [23.0, 23.0, 23.0, 23.0, 23.0], 'The city is upgrading its water infrastructure to ensure a reliable supply of clean water. The project includes replacing old pipes and installing new treatment facilities.'), + (24, [24.0, 24.0, 24.0, 24.0, 24.0], 'A new technology incubator has opened to support startups in the tech sector. The incubator provides office space, resources, and mentorship to help entrepreneurs succeed.'), + (25, [25.0, 25.0, 25.0, 25.0, 25.0], 'The city is planning to build a new bike lane network to promote cycling as a healthy and environmentally friendly mode of transportation. 
The project includes dedicated bike lanes and bike-sharing stations.'), + (26, [26.0, 26.0, 26.0, 26.0, 26.0], 'The local farmers market is reopening for the season, offering fresh produce, artisanal goods, and handmade crafts from local vendors.'), + (27, [27.0, 27.0, 27.0, 27.0, 27.0], 'A new educational program is being launched to support early childhood development. The program provides resources and training for parents and caregivers.'), + (28, [28.0, 28.0, 28.0, 28.0, 28.0], 'The city is organizing a series of concerts in the park, featuring performances by local bands and musicians. The concerts are free and open to the public.'), + (29, [29.0, 29.0, 29.0, 29.0, 29.0], 'A new senior center is opening, offering programs and services for older adults. Activities include fitness classes, educational workshops, and social events.'), + (30, [30.0, 30.0, 30.0, 30.0, 30.0], 'The city is implementing a new traffic management system to reduce congestion and improve safety. The system includes synchronized traffic lights and real-time traffic monitoring.'), + (31, [31.0, 31.0, 31.0, 31.0, 31.0], 'A new community outreach program is being launched to support at-risk youth. The program provides mentoring, tutoring, and recreational activities.'), + (32, [32.0, 32.0, 32.0, 32.0, 32.0], 'The city is hosting a series of public forums to discuss plans for future development. Residents are encouraged to attend and provide feedback on proposed projects.'), + (33, [33.0, 33.0, 33.0, 33.0, 33.0], 'A new public art installation is being unveiled in the downtown area. The installation features sculptures and murals by local artists and aims to beautify the urban landscape.'), + (34, [34.0, 34.0, 34.0, 34.0, 34.0], 'The local university is launching a new research center focused on sustainable development. The center will conduct research and provide education on environmental issues.'), + (35, [35.0, 35.0, 35.0, 35.0, 35.0], 'The city is planning to expand its public Wi-Fi network to provide free internet access in parks, libraries, and other public spaces.'), + (36, [36.0, 36.0, 36.0, 36.0, 36.0], 'A new community health clinic is opening, offering medical, dental, and mental health services. The clinic aims to provide affordable healthcare to underserved populations.'), + (37, [37.0, 37.0, 37.0, 37.0, 37.0], 'The city is implementing a new emergency alert system to provide residents with real-time information during emergencies. The system includes mobile alerts and social media updates.'), + (38, [38.0, 38.0, 38.0, 38.0, 38.0], 'A local nonprofit is organizing a job fair to connect job seekers with employers. The fair will feature workshops on resume writing, interview skills, and job search strategies.'), + (39, [39.0, 39.0, 39.0, 39.0, 39.0], 'The city is hosting a series of environmental cleanup events, encouraging residents to participate in efforts to clean up parks, rivers, and other natural areas.'), + (40, [40.0, 40.0, 40.0, 40.0, 40.0], 'A new fitness trail is being built in the local park, featuring exercise stations and informational signs to promote physical activity and wellness.'), + (41, [41.0, 41.0, 41.0, 41.0, 41.0], 'The city is launching a new initiative to support small businesses, offering grants, training, and resources to help entrepreneurs grow their businesses.'), + (42, [42.0, 42.0, 42.0, 42.0, 42.0], 'A new art school is opening, offering classes and workshops for aspiring artists of all ages. 
The school aims to foster creativity and provide a supportive environment for artistic development.'), + (43, [43.0, 43.0, 43.0, 43.0, 43.0], 'The city is planning to improve its public transportation system by introducing electric buses and expanding routes to underserved areas.'), + (44, [44.0, 44.0, 44.0, 44.0, 44.0], 'A new music festival is being organized to celebrate local talent and bring the community together. The festival will feature performances by local bands, food stalls, and family-friendly activities.'), + (45, [45.0, 45.0, 45.0, 45.0, 45.0], 'The city is implementing new measures to protect green spaces, including the creation of new parks and the preservation of existing natural areas.'), + (46, [46.0, 46.0, 46.0, 46.0, 46.0], 'A new housing project is being developed to provide affordable homes for low-income families. The project includes energy-efficient buildings and community amenities.'), + (47, [47.0, 47.0, 47.0, 47.0, 47.0], 'The city is hosting a series of workshops on entrepreneurship, providing training and resources for aspiring business owners. Topics include business planning, marketing, and finance.'), + (48, [48.0, 48.0, 48.0, 48.0, 48.0], 'A new public garden is being created to provide a space for residents to relax and enjoy nature. The garden will feature walking paths, benches, and a variety of plants.'), + (49, [49.0, 49.0, 49.0, 49.0, 49.0], 'The city is launching a campaign to promote public health, encouraging residents to get vaccinated, exercise regularly, and eat a balanced diet.'), + (50, [50.0, 50.0, 50.0, 50.0, 50.0], 'A new community theater is opening, offering performances, workshops, and classes for residents of all ages. The theater aims to make the performing arts accessible to everyone.') +] + +@pytest.fixture(scope="module") +def started_cluster(): + try: + cluster.start() + + for node in [node_s1_r1, node_s1_r2]: + node.query( + """ + CREATE TABLE IF NOT EXISTS local_table(id UInt32, vector Array(Float32), text String, CONSTRAINT check_length CHECK length(vector) = 5) ENGINE = ReplicatedMergeTree('/clickhouse/tables/shard_1/local_table', '{replica}') ORDER BY id; + """.format( + replica=node.name + ) + ) + + for node in [node_s2_r1, node_s2_r2]: + node.query( + """ + CREATE TABLE IF NOT EXISTS local_table(id UInt32, vector Array(Float32), text String, CONSTRAINT check_length CHECK length(vector) = 5) ENGINE = ReplicatedMergeTree('/clickhouse/tables/shard_2/local_table', '{replica}') ORDER BY id; + """.format( + replica=node.name + ) + ) + + node_s1_r1.query("CREATE TABLE distributed_table(id UInt32, vector Array(Float32), text String, CONSTRAINT check_length CHECK length(vector) = 5) ENGINE = Distributed(test_cluster, default, local_table, id);") + node_s1_r1.query("INSERT INTO distributed_table (id, vector, text) VALUES {}".format(", ".join(map(str, data)))) + node_s1_r1.query("ALTER TABLE local_table ON CLUSTER test_cluster ADD INDEX fts_ind text TYPE fts;") + node_s1_r1.query("ALTER TABLE local_table ON CLUSTER test_cluster MATERIALIZE INDEX fts_ind;") + + # Create a MergeTree table with the same data as baseline. 
+ node_s1_r1.query("CREATE TABLE test_table (id UInt32, vector Array(Float32), text String, CONSTRAINT check_length CHECK length(vector) = 5) ENGINE = MergeTree ORDER BY id;") + node_s1_r1.query("INSERT INTO test_table (id, vector, text) VALUES {}".format(", ".join(map(str, data)))) + node_s1_r1.query("ALTER TABLE test_table ADD INDEX fts_ind text TYPE fts;") + node_s1_r1.query("ALTER TABLE test_table MATERIALIZE INDEX fts_ind;") + + time.sleep(2) + + yield cluster + + finally: + cluster.shutdown() + +def test_distributed_hybrid_search(started_cluster): + + ## Test RSF fusion_type + assert node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city') AS score FROM test_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") == node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city') AS score FROM distributed_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") + + ## Test RRF fusion_type + assert node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RRF')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'built') AS score FROM test_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") == node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RRF')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'built') AS score FROM distributed_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") + + ## Test enable_nlq and operator + assert node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF', 'enable_nlq=false', 'operator=OR')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM test_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") == node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF', 'enable_nlq=false', 'operator=OR')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM distributed_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") + assert node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF', 'enable_nlq=false', 'operator=AND')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM test_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") == node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF', 'enable_nlq=false', 'operator=AND')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM distributed_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") + assert node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF', 'enable_nlq=true', 'operator=OR')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM test_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") == node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF', 'enable_nlq=true', 'operator=OR')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM distributed_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") + assert node_s1_r1.query("SELECT id, 
HybridSearch('fusion_type=RSF', 'enable_nlq=true', 'operator=AND')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM test_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") == node_s1_r1.query("SELECT id, HybridSearch('fusion_type=RSF', 'enable_nlq=true', 'operator=AND')(vector, text, [10.1, 20.3, 30.5, 40.7, 50.9], 'city AND new') AS score FROM distributed_table ORDER BY score DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1, enable_brute_force_vector_search = 1") diff --git a/tests/integration/test_mqvs_distributed_text_search/__init__.py b/tests/integration/test_mqvs_distributed_text_search/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integration/test_mqvs_distributed_text_search/configs/remote_servers.xml b/tests/integration/test_mqvs_distributed_text_search/configs/remote_servers.xml new file mode 100644 index 00000000000..791af83a2d6 --- /dev/null +++ b/tests/integration/test_mqvs_distributed_text_search/configs/remote_servers.xml @@ -0,0 +1,28 @@
+<clickhouse>
+    <remote_servers>
+        <test_cluster>
+            <shard>
+                <internal_replication>true</internal_replication>
+                <replica>
+                    <host>node1</host>
+                    <port>9000</port>
+                </replica>
+                <replica>
+                    <host>node2</host>
+                    <port>9000</port>
+                </replica>
+            </shard>
+            <shard>
+                <internal_replication>true</internal_replication>
+                <replica>
+                    <host>node3</host>
+                    <port>9000</port>
+                </replica>
+                <replica>
+                    <host>node4</host>
+                    <port>9000</port>
+                </replica>
+            </shard>
+        </test_cluster>
+    </remote_servers>
+</clickhouse>
diff --git a/tests/integration/test_mqvs_distributed_text_search/test.py b/tests/integration/test_mqvs_distributed_text_search/test.py new file mode 100644 index 00000000000..1c0c73c701e --- /dev/null +++ b/tests/integration/test_mqvs_distributed_text_search/test.py @@ -0,0 +1,121 @@ +import pytest +import time +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__) + +# shard1 +node_s1_r1 = cluster.add_instance("node1", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) +node_s1_r2 = cluster.add_instance("node2", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) + +# shard2 +node_s2_r1 = cluster.add_instance("node3", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) +node_s2_r2 = cluster.add_instance("node4", main_configs=["configs/remote_servers.xml"], with_zookeeper=True) + +data = [ + (1, 'The mayor announced a new initiative to revitalize the downtown area. This project will include the construction of new parks and the renovation of historic buildings.'), + (2, 'Local schools are introducing a new curriculum focused on science and technology. The goal is to better prepare students for careers in STEM fields.'), + (3, 'A new community center is opening next month, offering a variety of programs for residents of all ages. Activities include fitness classes, art workshops, and social events.'), + (4, 'The city council has approved a plan to improve public transportation. This includes expanding bus routes and adding more frequent services during peak hours.'), + (5, 'A new library is being built in the west side of the city. The library will feature modern facilities, including a digital media lab and community meeting rooms.'), + (6, 'The local hospital has received funding to upgrade its emergency department. The improvements will enhance patient care and reduce wait times.'), + (7, 'A new startup accelerator has been launched to support tech entrepreneurs. The program offers mentoring, networking opportunities, and access to investors.'), + (8, 'The city is hosting a series of public workshops on climate change. The sessions aim to educate residents on how to reduce their carbon footprint.'), + (9, 'A popular local restaurant is expanding with a new location in the downtown area. 
The restaurant is known for its farm-to-table cuisine and sustainable practices.'), + (10, 'The annual arts festival is set to begin next week, featuring performances, exhibitions, and workshops by local and international artists.'), + (11, 'The city is implementing new measures to improve air quality, including stricter emissions standards for industrial facilities and incentives for electric vehicles.'), + (12, 'A local nonprofit is organizing a food drive to support families in need. Donations of non-perishable food items can be dropped off at designated locations.'), + (13, 'The community garden project is expanding, with new plots available for residents to grow their own vegetables and herbs. The garden promotes healthy eating and sustainability.'), + (14, 'The police department is introducing body cameras for officers to increase transparency and accountability. The initiative is part of broader efforts to build trust with the community.'), + (15, 'A new public swimming pool is opening this summer, offering swimming lessons and recreational activities for all ages. The pool is part of a larger effort to promote health and wellness.'), + (16, 'The city is launching a campaign to promote recycling and reduce waste. Residents are encouraged to participate by recycling household items and composting organic waste.'), + (17, 'A local theater group is performing a series of classic plays at the community center. The performances aim to make theater accessible to a wider audience.'), + (18, 'The city is investing in renewable energy projects, including the installation of solar panels on public buildings and the development of wind farms.'), + (19, 'A new sports complex is being built to provide facilities for basketball, soccer, and other sports. The complex will also include a fitness center and walking trails.'), + (20, 'The city is hosting a series of workshops on financial literacy, aimed at helping residents manage their money and plan for the future. Topics include budgeting, saving, and investing.'), + (21, 'A new art exhibit is opening at the city museum, featuring works by contemporary artists from around the world. The exhibit aims to foster a greater appreciation for modern art.'), + (22, 'The local animal shelter is holding an adoption event this weekend. Dogs, cats, and other pets are available for adoption, and volunteers will be on hand to provide information.'), + (23, 'The city is upgrading its water infrastructure to ensure a reliable supply of clean water. The project includes replacing old pipes and installing new treatment facilities.'), + (24, 'A new technology incubator has opened to support startups in the tech sector. The incubator provides office space, resources, and mentorship to help entrepreneurs succeed.'), + (25, 'The city is planning to build a new bike lane network to promote cycling as a healthy and environmentally friendly mode of transportation. The project includes dedicated bike lanes and bike-sharing stations.'), + (26, 'The local farmers market is reopening for the season, offering fresh produce, artisanal goods, and handmade crafts from local vendors.'), + (27, 'A new educational program is being launched to support early childhood development. The program provides resources and training for parents and caregivers.'), + (28, 'The city is organizing a series of concerts in the park, featuring performances by local bands and musicians. 
The concerts are free and open to the public.'), + (29, 'A new senior center is opening, offering programs and services for older adults. Activities include fitness classes, educational workshops, and social events.'), + (30, 'The city is implementing a new traffic management system to reduce congestion and improve safety. The system includes synchronized traffic lights and real-time traffic monitoring.'), + (31, 'A new community outreach program is being launched to support at-risk youth. The program provides mentoring, tutoring, and recreational activities.'), + (32, 'The city is hosting a series of public forums to discuss plans for future development. Residents are encouraged to attend and provide feedback on proposed projects.'), + (33, 'A new public art installation is being unveiled in the downtown area. The installation features sculptures and murals by local artists and aims to beautify the urban landscape.'), + (34, 'The local university is launching a new research center focused on sustainable development. The center will conduct research and provide education on environmental issues.'), + (35, 'The city is planning to expand its public Wi-Fi network to provide free internet access in parks, libraries, and other public spaces.'), + (36, 'A new community health clinic is opening, offering medical, dental, and mental health services. The clinic aims to provide affordable healthcare to underserved populations.'), + (37, 'The city is implementing a new emergency alert system to provide residents with real-time information during emergencies. The system includes mobile alerts and social media updates.'), + (38, 'A local nonprofit is organizing a job fair to connect job seekers with employers. The fair will feature workshops on resume writing, interview skills, and job search strategies.'), + (39, 'The city is hosting a series of environmental cleanup events, encouraging residents to participate in efforts to clean up parks, rivers, and other natural areas.'), + (40, 'A new fitness trail is being built in the local park, featuring exercise stations and informational signs to promote physical activity and wellness.'), + (41, 'The city is launching a new initiative to support small businesses, offering grants, training, and resources to help entrepreneurs grow their businesses.'), + (42, 'A new art school is opening, offering classes and workshops for aspiring artists of all ages. The school aims to foster creativity and provide a supportive environment for artistic development.'), + (43, 'The city is planning to improve its public transportation system by introducing electric buses and expanding routes to underserved areas.'), + (44, 'A new music festival is being organized to celebrate local talent and bring the community together. The festival will feature performances by local bands, food stalls, and family-friendly activities.'), + (45, 'The city is implementing new measures to protect green spaces, including the creation of new parks and the preservation of existing natural areas.'), + (46, 'A new housing project is being developed to provide affordable homes for low-income families. The project includes energy-efficient buildings and community amenities.'), + (47, 'The city is hosting a series of workshops on entrepreneurship, providing training and resources for aspiring business owners. Topics include business planning, marketing, and finance.'), + (48, 'A new public garden is being created to provide a space for residents to relax and enjoy nature. 
The garden will feature walking paths, benches, and a variety of plants.'), + (49, 'The city is launching a campaign to promote public health, encouraging residents to get vaccinated, exercise regularly, and eat a balanced diet.'), + (50, 'A new community theater is opening, offering performances, workshops, and classes for residents of all ages. The theater aims to make the performing arts accessible to everyone.') + ] + +@pytest.fixture(scope="module") +def started_cluster(): + try: + cluster.start() + + for node in [node_s1_r1, node_s1_r2]: + node.query( + """ + CREATE TABLE IF NOT EXISTS local_table(id UInt32, text String) ENGINE = ReplicatedMergeTree('/clickhouse/tables/shard_1/local_table', '{replica}') ORDER BY id; + """.format( + replica=node.name + ) + ) + + for node in [node_s2_r1, node_s2_r2]: + node.query( + """ + CREATE TABLE IF NOT EXISTS local_table(id UInt32, text String) ENGINE = ReplicatedMergeTree('/clickhouse/tables/shard_2/local_table', '{replica}') ORDER BY id; + """.format( + replica=node.name + ) + ) + + node_s1_r1.query("CREATE TABLE distributed_table(id UInt32, text String) ENGINE = Distributed(test_cluster, default, local_table, id);") + node_s1_r1.query("INSERT INTO distributed_table (id, text) VALUES {}".format(", ".join(map(str, data)))) + node_s1_r1.query("ALTER TABLE local_table ON CLUSTER test_cluster ADD INDEX fts_ind text TYPE fts;") + node_s1_r1.query("ALTER TABLE local_table ON CLUSTER test_cluster MATERIALIZE INDEX fts_ind;") + + # Create a MergeTree table with the same data as the distributed table to serve as the ground truth for queries in the distributed TextSearch. + node_s1_r1.query("CREATE TABLE text_table (id UInt32, text String) ENGINE = MergeTree ORDER BY id;") + node_s1_r1.query("INSERT INTO text_table (id, text) VALUES {}".format(", ".join(map(str, data)))) + node_s1_r1.query("ALTER TABLE text_table ADD INDEX fts_ind text TYPE fts;") + node_s1_r1.query("ALTER TABLE text_table MATERIALIZE INDEX fts_ind;") + + time.sleep(2) + + yield cluster + + finally: + cluster.shutdown() + +def test_distributed_text_search(started_cluster): + # Test table function ftsIndex() + assert node_s1_r1.query("SELECT * FROM ftsIndex(default, text_table, text, 'new') ORDER BY field_tokens ASC") == "50\t[(1,1267)]\t[('new',1,30)]\n" + assert node_s1_r1.query("SELECT * FROM cluster(test_cluster, ftsIndex(default, local_table, text, 'new')) ORDER BY field_tokens ASC") == "25\t[(1,630)]\t[('new',1,19)]\n25\t[(1,637)]\t[('new',1,11)]\n" + + # Test query setting dfs_query_then_fetch does not affect the non-Distributed TextSearch results + assert node_s1_r1.query("SELECT id, TextSearch(text, 'city') AS bm25 FROM text_table ORDER BY bm25 DESC LIMIT 5 SETTINGS dfs_query_then_fetch = 1") == node_s1_r1.query("SELECT id, TextSearch(text, 'city') AS bm25 FROM text_table ORDER BY bm25 DESC LIMIT 5 SETTINGS dfs_query_then_fetch = 0") + + # Test Distributed TextSearch results is the same as the non-Distributed TextSearch results + assert node_s1_r1.query("SELECT id, TextSearch(text, 'city') AS bm25 FROM text_table ORDER BY bm25 DESC, id ASC LIMIT 5") == node_s1_r1.query("SELECT id, TextSearch(text, 'city') AS bm25 FROM distributed_table ORDER BY bm25 DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 1") + + # Test Distributed TextSearch results with dfs_query_then_fetch = 0 + assert node_s1_r1.query("SELECT id, TextSearch(text, 'city') AS bm25 FROM distributed_table ORDER BY bm25 DESC, id ASC LIMIT 5 SETTINGS dfs_query_then_fetch = 0") == 
"18\t1.1643934\n4\t1.1452436\n8\t1.1267136\n16\t1.1267136\n30\t1.1087735\n"