From c1901a7d05372e00fbd9c1c29d4c83aebeb16ae7 Mon Sep 17 00:00:00 2001 From: yangzq50 <58433399+yangzq50@users.noreply.github.com> Date: Fri, 1 Nov 2024 14:29:09 +0800 Subject: [PATCH] Initial support for boolean similarity (#2144) ### What problem does this PR solve? Initial support for boolean similarity Issue link:#2139 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring - [x] Test cases --- .../fulltext/fulltext_benchmark.cpp | 27 +++--- .../local_infinity/infinity_benchmark.cpp | 45 ++++------ src/executor/operator/physical_match.cpp | 13 +-- src/executor/operator/physical_match.cppm | 2 + src/executor/physical_planner.cpp | 1 + src/planner/bound_select_statement.cpp | 13 +++ src/planner/node/logical_match.cppm | 1 + ...r_expression_push_down_indexscanfilter.cpp | 17 +++- .../index_scan/index_filter_evaluators.cpp | 16 ++-- .../index_scan/index_filter_evaluators.cppm | 7 +- .../invertedindex/search/and_iterator.cpp | 18 ++-- .../invertedindex/search/and_iterator.cppm | 9 +- .../invertedindex/search/and_not_iterator.cpp | 2 +- .../search/and_not_iterator.cppm | 2 +- .../search/blockmax_wand_iterator.cppm | 4 +- .../invertedindex/search/doc_iterator.cppm | 2 +- .../invertedindex/search/filter_iterator.cpp | 7 +- .../invertedindex/search/filter_iterator.cppm | 7 +- .../search/minimum_should_match_iterator.cpp | 14 ++-- .../search/minimum_should_match_iterator.cppm | 8 +- .../invertedindex/search/or_iterator.cpp | 12 +-- .../invertedindex/search/or_iterator.cppm | 8 +- .../search/parse_fulltext_options.cppm | 5 ++ .../search/phrase_doc_iterator.cpp | 7 +- .../search/phrase_doc_iterator.cppm | 18 +++- .../invertedindex/search/query_builder.cpp | 13 ++- .../invertedindex/search/query_builder.cppm | 15 ++-- .../invertedindex/search/query_node.cpp | 84 ++++++++----------- src/storage/invertedindex/search/query_node.h | 45 ++++------ .../search/query_node_module.cppm | 2 +- .../search/score_threshold_iterator.cpp | 2 +- .../search/score_threshold_iterator.cppm | 2 +- .../search/term_doc_iterator.cpp | 4 +- .../search/term_doc_iterator.cppm | 20 ++++- .../invertedindex/search/query_builder.cpp | 24 +++--- .../invertedindex/search/query_match.cpp | 7 +- .../fulltext_minimum_should_match.slt | 6 ++ 37 files changed, 272 insertions(+), 217 deletions(-) diff --git a/benchmark/local_infinity/fulltext/fulltext_benchmark.cpp b/benchmark/local_infinity/fulltext/fulltext_benchmark.cpp index 0b0f3862aa..00a5a6b49d 100644 --- a/benchmark/local_infinity/fulltext/fulltext_benchmark.cpp +++ b/benchmark/local_infinity/fulltext/fulltext_benchmark.cpp @@ -45,6 +45,7 @@ import function_expr; import search_expr; import column_expr; import virtual_store; +import insert_row_expr; using namespace infinity; @@ -163,27 +164,27 @@ void BenchmarkInsert(SharedPtr infinity, const String &db_name, const profiler.Begin(); Vector orig_columns{"id", "title", "text"}; - ConstantExpr *const_expr = nullptr; + UniquePtr const_expr; SizeT num_inserted = 0; while (num_inserted < num_rows) { - Vector *columns = new Vector(orig_columns); - Vector *> *values = new Vector *>(); - values->reserve(insert_batch); + auto insert_rows = new Vector(); + insert_rows->reserve(insert_batch); for (SizeT i = 0; i < insert_batch && (num_inserted + i) < num_rows; i++) { auto &t = batch_cache[num_inserted + i]; - auto value_list = new Vector(columns->size()); - const_expr = new ConstantExpr(LiteralType::kString); + auto insert_row = MakeUnique(); + insert_row->columns_ = orig_columns; + const_expr = MakeUnique(LiteralType::kString); const_expr->str_value_ = std::get<0>(t); - value_list->at(0) = const_expr; - const_expr = new ConstantExpr(LiteralType::kString); + insert_row->values_.emplace_back(std::move(const_expr)); + const_expr = MakeUnique(LiteralType::kString); const_expr->str_value_ = std::get<1>(t); - value_list->at(1) = const_expr; - const_expr = new ConstantExpr(LiteralType::kString); + insert_row->values_.emplace_back(std::move(const_expr)); + const_expr = MakeUnique(LiteralType::kString); const_expr->str_value_ = std::get<2>(t); - value_list->at(2) = const_expr; - values->push_back(value_list); + insert_row->values_.emplace_back(std::move(const_expr)); + insert_rows->push_back(insert_row.release()); } - infinity->Insert(db_name, table_name, columns, values); + infinity->Insert(db_name, table_name, insert_rows); // NOTE: ~InsertStatement() has deleted or freed columns, values, value_list, const_expr, const_expr->str_value_ num_inserted += insert_batch; } diff --git a/benchmark/local_infinity/infinity_benchmark.cpp b/benchmark/local_infinity/infinity_benchmark.cpp index 3fe4226f79..528ca8387f 100644 --- a/benchmark/local_infinity/infinity_benchmark.cpp +++ b/benchmark/local_infinity/infinity_benchmark.cpp @@ -43,6 +43,7 @@ import column_def; import statement_common; import data_type; import virtual_store; +import insert_row_expr; using namespace infinity; @@ -232,22 +233,16 @@ int main() { { auto tims_costing_second = Measurement("Insert", thread_num, total_times, [&](SizeT i, SharedPtr infinity, std::thread::id thread_id) { - Vector *> *values = new Vector *>(); - values->emplace_back(new Vector()); - - Vector *columns = new Vector(); - columns->emplace_back(col_name_1); - columns->emplace_back(col_name_2); - - ConstantExpr *value1 = new ConstantExpr(LiteralType::kInteger); + auto insert_row = MakeUnique(); + insert_row->columns_ = {col_name_1, col_name_2}; + auto value1 = MakeUnique(LiteralType::kInteger); value1->integer_value_ = i; - values->at(0)->emplace_back(value1); - - ConstantExpr *value2 = new ConstantExpr(LiteralType::kInteger); + insert_row->values_.emplace_back(std::move(value1)); + auto value2 = MakeUnique(LiteralType::kInteger); value2->integer_value_ = i; - values->at(0)->emplace_back(value2); - - __attribute__((unused)) auto ignored = infinity->Insert("default_db", "benchmark_test", columns, values); + insert_row->values_.emplace_back(std::move(value2)); + auto insert_rows = new Vector({insert_row.release()}); + [[maybe_unused]] auto ignored = infinity->Insert("default_db", "benchmark_test", insert_rows); }); results.push_back(fmt::format("-> Insert QPS: {}", total_times / tims_costing_second)); } @@ -319,22 +314,16 @@ int main() { { auto tims_costing_second = Measurement("Insert for Select Sort", thread_num, sort_row, [&](SizeT i, SharedPtr infinity, std::thread::id thread_id) { - Vector *> *values = new Vector *>(); - values->emplace_back(new Vector()); - - Vector *columns = new Vector(); - columns->emplace_back(col_name_1); - columns->emplace_back(col_name_2); - - ConstantExpr *value1 = new ConstantExpr(LiteralType::kInteger); + auto insert_row = MakeUnique(); + insert_row->columns_ = {col_name_1, col_name_2}; + auto value1 = MakeUnique(LiteralType::kInteger); value1->integer_value_ = std::rand(); - values->at(0)->emplace_back(value1); - - ConstantExpr *value2 = new ConstantExpr(LiteralType::kInteger); + insert_row->values_.emplace_back(std::move(value1)); + auto value2 = MakeUnique(LiteralType::kInteger); value2->integer_value_ = std::rand(); - values->at(0)->emplace_back(value2); - - __attribute__((unused)) auto ignored = infinity->Insert("default_db", "benchmark_test", columns, values); + insert_row->values_.emplace_back(std::move(value2)); + auto insert_rows = new Vector({insert_row.release()}); + [[maybe_unused]] auto ignored = infinity->Insert("default_db", "benchmark_test", insert_rows); }); results.push_back(fmt::format("-> Insert for Sort Time: {}s", tims_costing_second)); } diff --git a/src/executor/operator/physical_match.cpp b/src/executor/operator/physical_match.cpp index 0d2c93b964..a5c956e219 100644 --- a/src/executor/operator/physical_match.cpp +++ b/src/executor/operator/physical_match.cpp @@ -97,7 +97,7 @@ void ExecuteFTSearch(UniquePtr &et_iter, FullTextScoreResultHeap &r break; } RowID id = et_iter->DocID(); - float et_score = et_iter->BM25Score(); + float et_score = et_iter->Score(); if (SHOULD_LOG_DEBUG()) { OStringStream oss; et_iter->PrintTree(oss, "", true); @@ -161,7 +161,7 @@ bool PhysicalMatch::ExecuteInnerHomebrewed(QueryContext *query_context, Operator // 2 build query iterator // result - FullTextQueryContext full_text_query_context; + FullTextQueryContext full_text_query_context(ft_similarity_, minimum_should_match_option_); u32 result_count = 0; const float *score_result = nullptr; const RowID *row_id_result = nullptr; @@ -182,7 +182,8 @@ bool PhysicalMatch::ExecuteInnerHomebrewed(QueryContext *query_context, Operator full_text_query_context.query_tree_ = MakeUnique(common_query_filter_.get(), std::move(query_tree_)); if (use_block_max_iter) { - et_iter = query_builder.CreateSearch(full_text_query_context, early_term_algo_, minimum_should_match_option_); + full_text_query_context.early_term_algo_ = early_term_algo_; + et_iter = query_builder.CreateSearch(full_text_query_context); // et_iter is nullptr if fulltext index is present but there's no data if (et_iter != nullptr) { et_iter->UpdateScoreThreshold(std::max(begin_threshold_, score_threshold_)); @@ -193,7 +194,8 @@ bool PhysicalMatch::ExecuteInnerHomebrewed(QueryContext *query_context, Operator } } if (use_ordinary_iter) { - doc_iterator = query_builder.CreateSearch(full_text_query_context, EarlyTermAlgo::kNaive, minimum_should_match_option_); + full_text_query_context.early_term_algo_ = EarlyTermAlgo::kNaive; + doc_iterator = query_builder.CreateSearch(full_text_query_context); if (doc_iterator && score_threshold_ > 0.0f) { auto new_doc_iter = MakeUnique(std::move(doc_iterator), score_threshold_); doc_iterator = std::move(new_doc_iter); @@ -351,6 +353,7 @@ PhysicalMatch::PhysicalMatch(const u64 id, const SharedPtr &common_query_filter, MinimumShouldMatchOption &&minimum_should_match_option, const f32 score_threshold, + const FulltextSimilarity ft_similarity, const u64 match_table_index, SharedPtr> load_metas, const bool cache_result) @@ -358,7 +361,7 @@ PhysicalMatch::PhysicalMatch(const u64 id, base_table_ref_(std::move(base_table_ref)), match_expr_(std::move(match_expr)), index_reader_(std::move(index_reader)), query_tree_(std::move(query_tree)), begin_threshold_(begin_threshold), early_term_algo_(early_term_algo), top_n_(top_n), common_query_filter_(common_query_filter), minimum_should_match_option_(std::move(minimum_should_match_option)), - score_threshold_(score_threshold) {} + score_threshold_(score_threshold), ft_similarity_(ft_similarity) {} PhysicalMatch::~PhysicalMatch() = default; diff --git a/src/executor/operator/physical_match.cppm b/src/executor/operator/physical_match.cppm index 7d3bddc847..b49bd0c5e7 100644 --- a/src/executor/operator/physical_match.cppm +++ b/src/executor/operator/physical_match.cppm @@ -55,6 +55,7 @@ public: const SharedPtr &common_query_filter, MinimumShouldMatchOption &&minimum_should_match_option, f32 score_threshold, + FulltextSimilarity ft_similarity, u64 match_table_index, SharedPtr> load_metas, bool cache_result); @@ -113,6 +114,7 @@ private: // for minimum_should_match MinimumShouldMatchOption minimum_should_match_option_{}; f32 score_threshold_{}; + FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25}; bool ExecuteInner(QueryContext *query_context, OperatorState *operator_state); bool ExecuteInnerHomebrewed(QueryContext *query_context, OperatorState *operator_state); diff --git a/src/executor/physical_planner.cpp b/src/executor/physical_planner.cpp index 4117849a92..1c225de9ed 100644 --- a/src/executor/physical_planner.cpp +++ b/src/executor/physical_planner.cpp @@ -963,6 +963,7 @@ UniquePtr PhysicalPlanner::BuildMatch(const SharedPtrcommon_query_filter_, std::move(logical_match->minimum_should_match_option_), logical_match->score_threshold_, + logical_match->ft_similarity_, logical_match->TableIndex(), logical_operator->load_metas(), true /*cache_result*/); diff --git a/src/planner/bound_select_statement.cpp b/src/planner/bound_select_statement.cpp index 537f28916f..70678e3e7d 100644 --- a/src/planner/bound_select_statement.cpp +++ b/src/planner/bound_select_statement.cpp @@ -267,6 +267,19 @@ SharedPtr BoundSelectStatement::BuildPlan(QueryContext *query_conte match_node->score_threshold_ = DataType::StringToValue(iter->second); } + // option: similarity + if (iter = search_ops.options_.find("similarity"); iter != search_ops.options_.end()) { + String ft_sim = iter->second; + ToLower(ft_sim); + if (ft_sim == "bm25") { + match_node->ft_similarity_ = FulltextSimilarity::kBM25; + } else if (ft_sim == "boolean") { + match_node->ft_similarity_ = FulltextSimilarity::kBoolean; + } else { + RecoverableError(Status::SyntaxError(R"(similarity option must be "BM25" or "boolean".)")); + } + } + SearchDriver search_driver(column2analyzer, default_field, query_operator_option); UniquePtr query_tree = search_driver.ParseSingleWithFields(match_node->match_expr_->fields_, match_node->match_expr_->matching_text_); diff --git a/src/planner/node/logical_match.cppm b/src/planner/node/logical_match.cppm index 163e5cc9c7..e55d40b65b 100644 --- a/src/planner/node/logical_match.cppm +++ b/src/planner/node/logical_match.cppm @@ -67,6 +67,7 @@ public: SharedPtr common_query_filter_{}; MinimumShouldMatchOption minimum_should_match_option_{}; f32 score_threshold_{}; + FulltextSimilarity ft_similarity_{FulltextSimilarity::kBM25}; }; } // namespace infinity diff --git a/src/planner/optimizer/index_scan/filter_expression_push_down_indexscanfilter.cpp b/src/planner/optimizer/index_scan/filter_expression_push_down_indexscanfilter.cpp index 36e0075474..4d88208ed7 100644 --- a/src/planner/optimizer/index_scan/filter_expression_push_down_indexscanfilter.cpp +++ b/src/planner/optimizer/index_scan/filter_expression_push_down_indexscanfilter.cpp @@ -460,6 +460,7 @@ class IndexScanFilterExpressionPushDownMethod { UniquePtr query_tree; MinimumShouldMatchOption minimum_should_match_option; f32 score_threshold = {}; + FulltextSimilarity ft_similarity = FulltextSimilarity::kBM25; { const Map &column2analyzer = index_reader.GetColumn2Analyzer(); SearchOptions search_ops(filter_fulltext_expr->options_text_); @@ -513,6 +514,19 @@ class IndexScanFilterExpressionPushDownMethod { score_threshold = DataType::StringToValue(iter->second); } + // option: similarity + if (iter = search_ops.options_.find("similarity"); iter != search_ops.options_.end()) { + String ft_sim = iter->second; + ToLower(ft_sim); + if (ft_sim == "bm25") { + ft_similarity = FulltextSimilarity::kBM25; + } else if (ft_sim == "boolean") { + ft_similarity = FulltextSimilarity::kBoolean; + } else { + RecoverableError(Status::SyntaxError(R"(similarity option must be "BM25" or "boolean".)")); + } + } + SearchDriver search_driver(column2analyzer, default_field, query_operator_option); query_tree = search_driver.ParseSingleWithFields(filter_fulltext_expr->fields_, filter_fulltext_expr->matching_text_); if (!query_tree) { @@ -525,7 +539,8 @@ class IndexScanFilterExpressionPushDownMethod { std::move(index_reader), std::move(query_tree), std::move(minimum_should_match_option), - score_threshold); + score_threshold, + ft_similarity); } case Enum::kAndExpr: { Vector> candidates; diff --git a/src/planner/optimizer/index_scan/index_filter_evaluators.cpp b/src/planner/optimizer/index_scan/index_filter_evaluators.cpp index 170e11990b..796970d2fc 100644 --- a/src/planner/optimizer/index_scan/index_filter_evaluators.cpp +++ b/src/planner/optimizer/index_scan/index_filter_evaluators.cpp @@ -57,7 +57,8 @@ void AddToFulltextEvaluator(UniquePtr &target_full } else { if (target_fulltext_evaluator->HaveMinimumShouldMatchOption() || input->HaveMinimumShouldMatchOption() || target_fulltext_evaluator->score_threshold_ > 0.0f || input->score_threshold_ > 0.0f || - target_fulltext_evaluator->early_term_algo_ != input->early_term_algo_) { + target_fulltext_evaluator->early_term_algo_ != input->early_term_algo_ || + target_fulltext_evaluator->ft_similarity_ != input->ft_similarity_) { // put into others other_children_evaluators.push_back(std::move(input)); } else { @@ -474,7 +475,8 @@ Bitmask IndexFilterEvaluatorFulltext::Evaluate(const SegmentID segment_id, const result.SetAllFalse(); const RowID begin_rowid(segment_id, 0); const RowID end_rowid(segment_id, segment_row_count); - auto ft_iter = query_tree_->CreateSearch(table_entry_, index_reader_, early_term_algo_, minimum_should_match_); + const CreateSearchParams params{table_entry_, &index_reader_, early_term_algo_, ft_similarity_, minimum_should_match_}; + auto ft_iter = query_tree_->CreateSearch(params); if (ft_iter && score_threshold_ > 0.0f) { auto new_ft_iter = MakeUnique(std::move(ft_iter), score_threshold_); ft_iter = std::move(new_ft_iter); @@ -512,10 +514,12 @@ Bitmask IndexFilterEvaluatorAND::Evaluate(const SegmentID segment_id, const Segm if (!fulltext_evaluator_->after_optimize_.test(std::memory_order_acquire)) { UnrecoverableError(std::format("{}: Not optimized!", __func__)); } - auto ft_iter = fulltext_evaluator_->query_tree_->CreateSearch(fulltext_evaluator_->table_entry_, - fulltext_evaluator_->index_reader_, - fulltext_evaluator_->early_term_algo_, - fulltext_evaluator_->minimum_should_match_); + const CreateSearchParams params{fulltext_evaluator_->table_entry_, + &(fulltext_evaluator_->index_reader_), + fulltext_evaluator_->early_term_algo_, + fulltext_evaluator_->ft_similarity_, + fulltext_evaluator_->minimum_should_match_}; + auto ft_iter = fulltext_evaluator_->query_tree_->CreateSearch(params); if (ft_iter && fulltext_evaluator_->score_threshold_ > 0.0f) { auto new_ft_iter = MakeUnique(std::move(ft_iter), fulltext_evaluator_->score_threshold_); ft_iter = std::move(new_ft_iter); diff --git a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm index 75b4d496e0..d3f6f2ff26 100644 --- a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm +++ b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm @@ -93,6 +93,7 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator { u32 minimum_should_match_ = 0; std::atomic_flag after_optimize_ = {}; f32 score_threshold_ = {}; + FulltextSimilarity ft_similarity_ = FulltextSimilarity::kBM25; IndexFilterEvaluatorFulltext(const FilterFulltextExpression *src_filter_fulltext_expression, const TableEntry *table_entry, @@ -100,10 +101,12 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator { IndexReader &&index_reader, UniquePtr &&query_tree, MinimumShouldMatchOption &&minimum_should_match_option, - const f32 score_threshold) + const f32 score_threshold, + const FulltextSimilarity ft_similarity) : IndexFilterEvaluator(Type::kFulltextIndex), src_filter_fulltext_expressions_({src_filter_fulltext_expression}), table_entry_(table_entry), early_term_algo_(early_term_algo), index_reader_(std::move(index_reader)), query_tree_(std::move(query_tree)), - minimum_should_match_option_(std::move(minimum_should_match_option)), score_threshold_(score_threshold) {} + minimum_should_match_option_(std::move(minimum_should_match_option)), score_threshold_(std::max(score_threshold, 0.0f)), + ft_similarity_(ft_similarity) {} Bitmask Evaluate(SegmentID segment_id, SegmentOffset segment_row_count, Txn *txn) const override; bool HaveMinimumShouldMatchOption() const { return !minimum_should_match_option_.empty(); } void OptimizeQueryTree(); diff --git a/src/storage/invertedindex/search/and_iterator.cpp b/src/storage/invertedindex/search/and_iterator.cpp index 0fe8ec6652..6021caf214 100644 --- a/src/storage/invertedindex/search/and_iterator.cpp +++ b/src/storage/invertedindex/search/and_iterator.cpp @@ -78,12 +78,12 @@ bool AndIterator::Next(const RowID doc_id) { float sum_score = 0.0f; for (SizeT i = 0; i < children_.size(); i++) { const auto &it = children_[i]; - sum_score += it->BM25Score(); + sum_score += it->Score(); } if (sum_score > threshold_) { doc_id_ = target_doc_id; - bm25_score_cache_ = sum_score; - bm25_score_cache_docid_ = doc_id_; + score_cache_ = sum_score; + score_cache_docid_ = doc_id_; return true; } ++target_doc_id; @@ -91,17 +91,17 @@ bool AndIterator::Next(const RowID doc_id) { } } -float AndIterator::BM25Score() { - if (bm25_score_cache_docid_ == doc_id_) { - return bm25_score_cache_; +float AndIterator::Score() { + if (score_cache_docid_ == doc_id_) { + return score_cache_; } float sum_score = 0.0f; for (SizeT i = 0; i < children_.size(); i++) { const auto &it = children_[i]; - sum_score += it->BM25Score(); + sum_score += it->Score(); } - bm25_score_cache_docid_ = doc_id_; - bm25_score_cache_ = sum_score; + score_cache_docid_ = doc_id_; + score_cache_ = sum_score; return sum_score; } diff --git a/src/storage/invertedindex/search/and_iterator.cppm b/src/storage/invertedindex/search/and_iterator.cppm index 405dadf72b..99a5570231 100644 --- a/src/storage/invertedindex/search/and_iterator.cppm +++ b/src/storage/invertedindex/search/and_iterator.cppm @@ -35,18 +35,19 @@ public: /* pure virtual methods implementation */ bool Next(RowID doc_id) override; - float BM25Score() override; + float Score() override; void UpdateScoreThreshold(float threshold) override; u32 MatchCount() const override; private: - // bm25 score cache - RowID bm25_score_cache_docid_ = INVALID_ROWID; - float bm25_score_cache_ = 0.0f; + // score cache + RowID score_cache_docid_ = INVALID_ROWID; + float score_cache_ = 0.0f; // for minimum_should_match u32 fixed_match_count_ = 0; Vector dyn_match_ids_{}; }; + } // namespace infinity \ No newline at end of file diff --git a/src/storage/invertedindex/search/and_not_iterator.cpp b/src/storage/invertedindex/search/and_not_iterator.cpp index 938977c182..ab3eb766df 100644 --- a/src/storage/invertedindex/search/and_not_iterator.cpp +++ b/src/storage/invertedindex/search/and_not_iterator.cpp @@ -55,7 +55,7 @@ bool AndNotIterator::Next(RowID doc_id) { return doc_id != INVALID_ROWID; } -float AndNotIterator::BM25Score() { return children_[0]->BM25Score(); } +float AndNotIterator::Score() { return children_[0]->Score(); } void AndNotIterator::UpdateScoreThreshold(float threshold) { children_[0]->UpdateScoreThreshold(threshold); } diff --git a/src/storage/invertedindex/search/and_not_iterator.cppm b/src/storage/invertedindex/search/and_not_iterator.cppm index 245bfb8488..1e1f4f5a8b 100644 --- a/src/storage/invertedindex/search/and_not_iterator.cppm +++ b/src/storage/invertedindex/search/and_not_iterator.cppm @@ -34,7 +34,7 @@ public: /* pure virtual methods implementation */ bool Next(RowID doc_id) override; - float BM25Score() override; + float Score() override; void UpdateScoreThreshold(float threshold) override; diff --git a/src/storage/invertedindex/search/blockmax_wand_iterator.cppm b/src/storage/invertedindex/search/blockmax_wand_iterator.cppm index a3a1a1681d..9678fec919 100644 --- a/src/storage/invertedindex/search/blockmax_wand_iterator.cppm +++ b/src/storage/invertedindex/search/blockmax_wand_iterator.cppm @@ -39,7 +39,9 @@ public: bool Next(RowID doc_id) override; - float BM25Score() override; + float BM25Score(); + + float Score() override { return BM25Score(); } u32 MatchCount() const override; diff --git a/src/storage/invertedindex/search/doc_iterator.cppm b/src/storage/invertedindex/search/doc_iterator.cppm index 13a8074037..59359db6c0 100644 --- a/src/storage/invertedindex/search/doc_iterator.cppm +++ b/src/storage/invertedindex/search/doc_iterator.cppm @@ -83,7 +83,7 @@ public: // If has_blockmax is true, it ensures its BM25 score be larger than current threshold. virtual bool Next(RowID doc_id) = 0; - virtual float BM25Score() = 0; + virtual float Score() = 0; virtual void UpdateScoreThreshold(float threshold) = 0; diff --git a/src/storage/invertedindex/search/filter_iterator.cpp b/src/storage/invertedindex/search/filter_iterator.cpp index 3194ee1c59..bb8bff5e63 100644 --- a/src/storage/invertedindex/search/filter_iterator.cpp +++ b/src/storage/invertedindex/search/filter_iterator.cpp @@ -42,14 +42,11 @@ void FilterIterator::PrintTree(std::ostream &os, const String &prefix, const boo query_iterator_->PrintTree(os, next_prefix, true); } -UniquePtr FilterQueryNode::CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - const EarlyTermAlgo early_term_algo, - const u32 minimum_should_match) const { +UniquePtr FilterQueryNode::CreateSearch(const CreateSearchParams params) const { assert(common_query_filter_ != nullptr); if (!common_query_filter_->AlwaysTrue() && common_query_filter_->filter_result_count_ == 0) return nullptr; - auto search_iter = query_tree_->CreateSearch(table_entry, index_reader, early_term_algo, minimum_should_match); + auto search_iter = query_tree_->CreateSearch(params); if (!search_iter) { return nullptr; } diff --git a/src/storage/invertedindex/search/filter_iterator.cppm b/src/storage/invertedindex/search/filter_iterator.cppm index b5d1befd76..f8b34fbc60 100644 --- a/src/storage/invertedindex/search/filter_iterator.cppm +++ b/src/storage/invertedindex/search/filter_iterator.cppm @@ -52,7 +52,7 @@ public: } } - float BM25Score() override { return query_iterator_->BM25Score(); } + float Score() override { return query_iterator_->Score(); } void UpdateScoreThreshold(const float threshold) override { query_iterator_->UpdateScoreThreshold(threshold); } @@ -88,10 +88,7 @@ export struct FilterQueryNode final : public QueryNode { void PushDownWeight(const float factor) override { MultiplyWeight(factor); } - UniquePtr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - u32 minimum_should_match) const override; + UniquePtr CreateSearch(CreateSearchParams params) const override; void PrintTree(std::ostream &os, const String &prefix, bool is_final) const override; diff --git a/src/storage/invertedindex/search/minimum_should_match_iterator.cpp b/src/storage/invertedindex/search/minimum_should_match_iterator.cpp index 768fd68f37..73f745260c 100644 --- a/src/storage/invertedindex/search/minimum_should_match_iterator.cpp +++ b/src/storage/invertedindex/search/minimum_should_match_iterator.cpp @@ -48,7 +48,7 @@ MinimumShouldMatchIterator::MinimumShouldMatchIterator(Vector 1"); } tail_heap_.resize(minimum_should_match_ - 1u); - bm25_score_cache_docid_ = INVALID_ROWID; + score_cache_docid_ = INVALID_ROWID; bm25_score_upper_bound_ = 0.0f; estimate_iterate_cost_ = {}; for (const auto &child : children_) { @@ -123,9 +123,9 @@ bool MinimumShouldMatchIterator::Next(RowID doc_id) { } } -float MinimumShouldMatchIterator::BM25Score() { - if (bm25_score_cache_docid_ == doc_id_) { - return bm25_score_cache_; +float MinimumShouldMatchIterator::Score() { + if (score_cache_docid_ == doc_id_) { + return score_cache_; } while (tail_size_) { // advance tail @@ -139,10 +139,10 @@ float MinimumShouldMatchIterator::BM25Score() { } float sum_score = 0; for (const auto idx : lead_) { - sum_score += children_[idx]->BM25Score(); + sum_score += children_[idx]->Score(); } - bm25_score_cache_docid_ = doc_id_; - bm25_score_cache_ = sum_score; + score_cache_docid_ = doc_id_; + score_cache_ = sum_score; return sum_score; } diff --git a/src/storage/invertedindex/search/minimum_should_match_iterator.cppm b/src/storage/invertedindex/search/minimum_should_match_iterator.cppm index e234c21407..f9a824f347 100644 --- a/src/storage/invertedindex/search/minimum_should_match_iterator.cppm +++ b/src/storage/invertedindex/search/minimum_should_match_iterator.cppm @@ -36,7 +36,7 @@ public: bool Next(RowID doc_id) override; - float BM25Score() override; + float Score() override; u32 MatchCount() const override; @@ -53,9 +53,9 @@ private: Vector tail_heap_{}; u32 tail_size_ = 0; - // bm25 score cache - RowID bm25_score_cache_docid_ = {}; - float bm25_score_cache_ = {}; + // score cache + RowID score_cache_docid_ = {}; + float score_cache_ = {}; }; export template T> diff --git a/src/storage/invertedindex/search/or_iterator.cpp b/src/storage/invertedindex/search/or_iterator.cpp index 9645d0936a..3d6cb00e06 100644 --- a/src/storage/invertedindex/search/or_iterator.cpp +++ b/src/storage/invertedindex/search/or_iterator.cpp @@ -84,18 +84,18 @@ bool OrIterator::Next(const RowID doc_id) { return doc_id_ != INVALID_ROWID; } -float OrIterator::BM25Score() { - if (bm25_score_cache_docid_ == doc_id_) { - return bm25_score_cache_; +float OrIterator::Score() { + if (score_cache_docid_ == doc_id_) { + return score_cache_; } float sum_score = 0; for (const auto &child : children_) { if (child->DocID() == doc_id_) { - sum_score += child->BM25Score(); + sum_score += child->Score(); } } - bm25_score_cache_docid_ = doc_id_; - bm25_score_cache_ = sum_score; + score_cache_docid_ = doc_id_; + score_cache_ = sum_score; return sum_score; } diff --git a/src/storage/invertedindex/search/or_iterator.cppm b/src/storage/invertedindex/search/or_iterator.cppm index fba53cc8c0..4c85a1e339 100644 --- a/src/storage/invertedindex/search/or_iterator.cppm +++ b/src/storage/invertedindex/search/or_iterator.cppm @@ -56,7 +56,7 @@ public: /* pure virtual methods implementation */ bool Next(RowID doc_id) override; - float BM25Score() override; + float Score() override; void UpdateScoreThreshold(float threshold) override; @@ -68,9 +68,9 @@ private: const DocIterator *GetDocIterator(u32 i) const { return children_[i].get(); } DocIteratorHeap heap_; - // bm25 score cache - RowID bm25_score_cache_docid_ = INVALID_ROWID; - float bm25_score_cache_ = 0.0f; + // score cache + RowID score_cache_docid_ = INVALID_ROWID; + float score_cache_ = 0.0f; }; } // namespace infinity diff --git a/src/storage/invertedindex/search/parse_fulltext_options.cppm b/src/storage/invertedindex/search/parse_fulltext_options.cppm index b86147f9a0..c30f144d39 100644 --- a/src/storage/invertedindex/search/parse_fulltext_options.cppm +++ b/src/storage/invertedindex/search/parse_fulltext_options.cppm @@ -37,4 +37,9 @@ export MinimumShouldMatchOption ParseMinimumShouldMatchOption(std::string_view i export u32 GetMinimumShouldMatchParameter(const MinimumShouldMatchOption &option_vec, u32 leaf_count); +export enum class FulltextSimilarity { + kBM25, + kBoolean, +}; + } // namespace infinity diff --git a/src/storage/invertedindex/search/phrase_doc_iterator.cpp b/src/storage/invertedindex/search/phrase_doc_iterator.cpp index 04d0534d40..8a0b2208d9 100644 --- a/src/storage/invertedindex/search/phrase_doc_iterator.cpp +++ b/src/storage/invertedindex/search/phrase_doc_iterator.cpp @@ -15,8 +15,11 @@ import logger; namespace infinity { -PhraseDocIterator::PhraseDocIterator(Vector> &&iters, const float weight, const u32 slop) - : pos_iters_(std::move(iters)), weight_(weight), slop_(slop) { +PhraseDocIterator::PhraseDocIterator(Vector> &&iters, + const float weight, + const u32 slop, + const FulltextSimilarity ft_similarity) + : pos_iters_(std::move(iters)), weight_(weight), slop_(slop), ft_similarity_(ft_similarity) { doc_freq_ = 0; phrase_freq_ = 0; if (pos_iters_.size()) { diff --git a/src/storage/invertedindex/search/phrase_doc_iterator.cppm b/src/storage/invertedindex/search/phrase_doc_iterator.cppm index f8daaf85ab..35f80546c7 100644 --- a/src/storage/invertedindex/search/phrase_doc_iterator.cppm +++ b/src/storage/invertedindex/search/phrase_doc_iterator.cppm @@ -9,11 +9,13 @@ import third_party; import posting_iterator; import index_defines; import column_length_io; +import parse_fulltext_options; namespace infinity { + export class PhraseDocIterator final : public DocIterator { public: - PhraseDocIterator(Vector> &&iters, float weight, u32 slop = 0); + PhraseDocIterator(Vector> &&iters, float weight, u32 slop, FulltextSimilarity ft_similarity); inline u32 GetDocFreq() const { return doc_freq_; } @@ -30,7 +32,18 @@ public: bool Next(RowID doc_id) override; - float BM25Score() override; + float BM25Score(); + + float Score() override { + switch (ft_similarity_) { + case FulltextSimilarity::kBM25: { + return BM25Score(); + } + case FulltextSimilarity::kBoolean: { + return GetWeight(); + } + } + } void UpdateScoreThreshold(float threshold) override { if (threshold > threshold_) @@ -61,6 +74,7 @@ private: Vector> pos_iters_; float weight_; u32 slop_{}; + const FulltextSimilarity ft_similarity_ = FulltextSimilarity::kBM25; // for BM25 Score float bm25_score_cache_ = 0.0f; diff --git a/src/storage/invertedindex/search/query_builder.cpp b/src/storage/invertedindex/search/query_builder.cpp index 8e33e69e54..55b75b5f14 100644 --- a/src/storage/invertedindex/search/query_builder.cpp +++ b/src/storage/invertedindex/search/query_builder.cpp @@ -37,23 +37,22 @@ import parse_fulltext_options; namespace infinity { -void QueryBuilder::Init(IndexReader index_reader) { index_reader_ = index_reader; } +void QueryBuilder::Init(IndexReader index_reader) { index_reader_ = std::move(index_reader); } QueryBuilder::~QueryBuilder() {} -UniquePtr QueryBuilder::CreateSearch(FullTextQueryContext &context, - EarlyTermAlgo early_term_algo, - const MinimumShouldMatchOption &minimum_should_match_option) { +UniquePtr QueryBuilder::CreateSearch(FullTextQueryContext &context) { // Optimize the query tree. if (!context.optimized_query_tree_) { context.optimized_query_tree_ = QueryNode::GetOptimizedQueryTree(std::move(context.query_tree_)); - if (!minimum_should_match_option.empty()) { + if (!context.minimum_should_match_option_.empty()) { const auto leaf_count = context.optimized_query_tree_->LeafCount(); - context.minimum_should_match_ = GetMinimumShouldMatchParameter(minimum_should_match_option, leaf_count); + context.minimum_should_match_ = GetMinimumShouldMatchParameter(context.minimum_should_match_option_, leaf_count); } } // Create the iterator from the query tree. - auto result = context.optimized_query_tree_->CreateSearch(table_entry_, index_reader_, early_term_algo, context.minimum_should_match_); + const CreateSearchParams params{table_entry_, &index_reader_, context.early_term_algo_, context.ft_similarity_, context.minimum_should_match_}; + auto result = context.optimized_query_tree_->CreateSearch(params); #ifdef INFINITY_DEBUG { OStringStream oss; diff --git a/src/storage/invertedindex/search/query_builder.cppm b/src/storage/invertedindex/search/query_builder.cppm index 0127ee83e3..8170e4f81f 100644 --- a/src/storage/invertedindex/search/query_builder.cppm +++ b/src/storage/invertedindex/search/query_builder.cppm @@ -24,15 +24,21 @@ import internal_types; import default_values; import base_table_ref; import parse_fulltext_options; +import query_node; namespace infinity { class Txn; -struct QueryNode; + export struct FullTextQueryContext { - UniquePtr query_tree_; - UniquePtr optimized_query_tree_; + UniquePtr query_tree_{}; + UniquePtr optimized_query_tree_{}; + const FulltextSimilarity ft_similarity_{}; + const MinimumShouldMatchOption minimum_should_match_option_{}; u32 minimum_should_match_ = 0; + EarlyTermAlgo early_term_algo_ = EarlyTermAlgo::kNaive; + FullTextQueryContext(const FulltextSimilarity ft_similarity, const MinimumShouldMatchOption &minimum_should_match_option) + : ft_similarity_(ft_similarity), minimum_should_match_option_(minimum_should_match_option) {} }; export class QueryBuilder { @@ -46,8 +52,7 @@ public: const Map &GetColumn2Analyzer() { return index_reader_.GetColumn2Analyzer(); } - UniquePtr - CreateSearch(FullTextQueryContext &context, EarlyTermAlgo early_term_algo, const MinimumShouldMatchOption &minimum_should_match_option); + UniquePtr CreateSearch(FullTextQueryContext &context); private: BaseTableRef* base_table_ref_{nullptr}; diff --git a/src/storage/invertedindex/search/query_node.cpp b/src/storage/invertedindex/search/query_node.cpp index 1fa26e3015..aff687e9e6 100644 --- a/src/storage/invertedindex/search/query_node.cpp +++ b/src/storage/invertedindex/search/query_node.cpp @@ -21,6 +21,7 @@ import term_doc_iterator; import phrase_doc_iterator; import blockmax_wand_iterator; import minimum_should_match_iterator; +import parse_fulltext_options; namespace infinity { @@ -402,12 +403,9 @@ std::unique_ptr AndNotQueryNode::InnerGetNewOptimizedQueryTree() { } // create search iterator -std::unique_ptr TermQueryNode::CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo /*early_term_algo*/, - u32 /*minimum_should_match*/) const { - ColumnID column_id = table_entry->GetColumnIdByName(column_); - ColumnIndexReader *column_index_reader = index_reader.GetColumnIndexReader(column_id); +std::unique_ptr TermQueryNode::CreateSearch(const CreateSearchParams params) const { + ColumnID column_id = params.table_entry->GetColumnIdByName(column_); + ColumnIndexReader *column_index_reader = params.index_reader->GetColumnIndexReader(column_id); if (!column_index_reader) { RecoverableError(Status::SyntaxError(fmt::format(R"(Invalid query statement: Column "{}" has no fulltext index)", column_))); return nullptr; @@ -422,7 +420,7 @@ std::unique_ptr TermQueryNode::CreateSearch(const TableEntry *table if (!posting_iterator) { return nullptr; } - auto search = MakeUnique(std::move(posting_iterator), column_id, GetWeight()); + auto search = MakeUnique(std::move(posting_iterator), column_id, GetWeight(), params.ft_similarity); auto column_length_reader = MakeUnique(column_index_reader); search->InitBM25Info(std::move(column_length_reader)); search->term_ptr_ = &term_; @@ -430,12 +428,9 @@ std::unique_ptr TermQueryNode::CreateSearch(const TableEntry *table return search; } -std::unique_ptr PhraseQueryNode::CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo /*early_term_algo*/, - u32 /*minimum_should_match*/) const { - ColumnID column_id = table_entry->GetColumnIdByName(column_); - ColumnIndexReader *column_index_reader = index_reader.GetColumnIndexReader(column_id); +std::unique_ptr PhraseQueryNode::CreateSearch(const CreateSearchParams params) const { + ColumnID column_id = params.table_entry->GetColumnIdByName(column_); + ColumnIndexReader *column_index_reader = params.index_reader->GetColumnIndexReader(column_id); if (!column_index_reader) { RecoverableError(Status::SyntaxError(fmt::format(R"(Invalid query statement: Column "{}" has no fulltext index)", column_))); return nullptr; @@ -453,7 +448,7 @@ std::unique_ptr PhraseQueryNode::CreateSearch(const TableEntry *tab } posting_iterators.emplace_back(std::move(posting_iterator)); } - auto search = MakeUnique(std::move(posting_iterators), GetWeight(), slop_); + auto search = MakeUnique(std::move(posting_iterators), GetWeight(), slop_, params.ft_similarity); auto column_length_reader = MakeUnique(column_index_reader); search->InitBM25Info(std::move(column_length_reader)); search->terms_ptr_ = &terms_; @@ -461,14 +456,12 @@ std::unique_ptr PhraseQueryNode::CreateSearch(const TableEntry *tab return search; } -std::unique_ptr AndQueryNode::CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - const EarlyTermAlgo early_term_algo, - const u32 minimum_should_match) const { +std::unique_ptr AndQueryNode::CreateSearch(const CreateSearchParams params) const { Vector> sub_doc_iters; sub_doc_iters.reserve(children_.size()); + const auto next_params = params.RemoveMSM(); for (auto &child : children_) { - auto iter = child->CreateSearch(table_entry, index_reader, early_term_algo, 0); + auto iter = child->CreateSearch(next_params); if (!iter) { // no need to continue if any child is invalid return nullptr; @@ -479,29 +472,27 @@ std::unique_ptr AndQueryNode::CreateSearch(const TableEntry *table_ return nullptr; } else if (sub_doc_iters.size() == 1) { return std::move(sub_doc_iters[0]); - } else if (minimum_should_match <= sub_doc_iters.size()) { + } else if (params.minimum_should_match <= sub_doc_iters.size()) { return MakeUnique(std::move(sub_doc_iters)); } else { - assert(minimum_should_match > 2u); - return MakeUnique>(std::move(sub_doc_iters), minimum_should_match); + assert(params.minimum_should_match > 2u); + return MakeUnique>(std::move(sub_doc_iters), params.minimum_should_match); } } -std::unique_ptr AndNotQueryNode::CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - const EarlyTermAlgo early_term_algo, - const u32 minimum_should_match) const { +std::unique_ptr AndNotQueryNode::CreateSearch(const CreateSearchParams params) const { Vector> sub_doc_iters; sub_doc_iters.reserve(children_.size()); // check if the first child is a valid query - auto first_iter = children_.front()->CreateSearch(table_entry, index_reader, early_term_algo, minimum_should_match); + auto first_iter = children_.front()->CreateSearch(params); if (!first_iter) { // no need to continue if the first child is invalid return nullptr; } sub_doc_iters.emplace_back(std::move(first_iter)); + const auto next_params = params.RemoveMSM(); for (u32 i = 1; i < children_.size(); ++i) { - auto iter = children_[i]->CreateSearch(table_entry, index_reader, early_term_algo, 0); + auto iter = children_[i]->CreateSearch(next_params); if (iter) { sub_doc_iters.emplace_back(std::move(iter)); } @@ -513,17 +504,15 @@ std::unique_ptr AndNotQueryNode::CreateSearch(const TableEntry *tab } } -std::unique_ptr OrQueryNode::CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - const EarlyTermAlgo early_term_algo, - const u32 minimum_should_match) const { +std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams params) const { Vector> sub_doc_iters; sub_doc_iters.reserve(children_.size()); bool all_are_term = true; bool all_are_term_or_phrase = true; const QueryNode *only_child_ = nullptr; + const auto next_params = params.RemoveMSM(); for (auto &child : children_) { - if (auto iter = child->CreateSearch(table_entry, index_reader, early_term_algo, 0); iter) { + if (auto iter = child->CreateSearch(next_params); iter) { only_child_ = child.get(); sub_doc_iters.emplace_back(std::move(iter)); if (const auto child_type = child->GetType(); child_type != QueryNodeType::TERM) { @@ -537,40 +526,37 @@ std::unique_ptr OrQueryNode::CreateSearch(const TableEntry *table_e if (sub_doc_iters.empty()) { return nullptr; } else if (sub_doc_iters.size() == 1) { - return only_child_->CreateSearch(table_entry, index_reader, early_term_algo, minimum_should_match); - } else if (all_are_term && early_term_algo == EarlyTermAlgo::kBMW) { - if (minimum_should_match <= 1u) { + return only_child_->CreateSearch(params); + } else if (all_are_term && params.ft_similarity == FulltextSimilarity::kBM25 && params.early_term_algo == EarlyTermAlgo::kBMW) { + if (params.minimum_should_match <= 1u) { return MakeUnique(std::move(sub_doc_iters)); - } else if (minimum_should_match < sub_doc_iters.size()) { - return MakeUnique>(std::move(sub_doc_iters), minimum_should_match); - } else if (minimum_should_match == sub_doc_iters.size()) { + } else if (params.minimum_should_match < sub_doc_iters.size()) { + return MakeUnique>(std::move(sub_doc_iters), params.minimum_should_match); + } else if (params.minimum_should_match == sub_doc_iters.size()) { return MakeUnique(std::move(sub_doc_iters)); } else { return nullptr; } } else if (all_are_term_or_phrase) { - if (minimum_should_match <= 1u) { + if (params.minimum_should_match <= 1u) { return MakeUnique(std::move(sub_doc_iters)); - } else if (minimum_should_match < sub_doc_iters.size()) { - return MakeUnique(std::move(sub_doc_iters), minimum_should_match); - } else if (minimum_should_match == sub_doc_iters.size()) { + } else if (params.minimum_should_match < sub_doc_iters.size()) { + return MakeUnique(std::move(sub_doc_iters), params.minimum_should_match); + } else if (params.minimum_should_match == sub_doc_iters.size()) { return MakeUnique(std::move(sub_doc_iters)); } else { return nullptr; } } else { - if (minimum_should_match <= 1u) { + if (params.minimum_should_match <= 1u) { return MakeUnique(std::move(sub_doc_iters)); } else { - return MakeUnique>(std::move(sub_doc_iters), minimum_should_match); + return MakeUnique>(std::move(sub_doc_iters), params.minimum_should_match); } } } -std::unique_ptr NotQueryNode::CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - u32 minimum_should_match) const { +std::unique_ptr NotQueryNode::CreateSearch(CreateSearchParams) const { UnrecoverableError("NOT query node should be optimized into AND_NOT query node"); return nullptr; } diff --git a/src/storage/invertedindex/search/query_node.h b/src/storage/invertedindex/search/query_node.h index e83d081b3d..b5a6429f7a 100644 --- a/src/storage/invertedindex/search/query_node.h +++ b/src/storage/invertedindex/search/query_node.h @@ -51,6 +51,16 @@ class Scorer; class DocIterator; class EarlyTerminateIterator; enum class EarlyTermAlgo; +enum class FulltextSimilarity; + +struct CreateSearchParams { + const TableEntry *table_entry; + const IndexReader *index_reader; + EarlyTermAlgo early_term_algo; + FulltextSimilarity ft_similarity; + uint32_t minimum_should_match; + [[nodiscard]] CreateSearchParams RemoveMSM() const { return {table_entry, index_reader, early_term_algo, ft_similarity, 0}; } +}; // step 1. get the query tree from parser // step 2. push down the weight to the leaf term node @@ -83,10 +93,7 @@ struct QueryNode { // recursively multiply and push down the weight to the leaf term nodes virtual void PushDownWeight(float factor = 1.0f) = 0; // create the iterator from the query tree, need to be called after optimization - virtual std::unique_ptr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - uint32_t minimum_should_match) const = 0; + virtual std::unique_ptr CreateSearch(CreateSearchParams params) const = 0; // print the query tree, for debugging virtual void PrintTree(std::ostream &os, const std::string &prefix = "", bool is_final = true) const = 0; @@ -102,10 +109,7 @@ struct TermQueryNode : public QueryNode { uint32_t LeafCount() const override { return 1; } void PushDownWeight(float factor) override { MultiplyWeight(factor); } - std::unique_ptr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - uint32_t minimum_should_match) const override; + std::unique_ptr CreateSearch(CreateSearchParams params) const override; void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const override; void GetQueryColumnsTerms(std::vector &columns, std::vector &terms) const override; }; @@ -119,10 +123,7 @@ struct PhraseQueryNode final : public QueryNode { uint32_t LeafCount() const override { return 1; } void PushDownWeight(float factor) override { MultiplyWeight(factor); } - std::unique_ptr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - uint32_t minimum_should_match) const override; + std::unique_ptr CreateSearch(CreateSearchParams params) const override; void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const override; void GetQueryColumnsTerms(std::vector &columns, std::vector &terms) const override; @@ -157,38 +158,26 @@ struct NotQueryNode final : public MultiQueryNode { NotQueryNode() : MultiQueryNode(QueryNodeType::NOT) {} uint32_t LeafCount() const override { return 0; } std::unique_ptr InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - uint32_t minimum_should_match) const override; + std::unique_ptr CreateSearch(CreateSearchParams params) const override; }; struct AndQueryNode final : public MultiQueryNode { AndQueryNode() : MultiQueryNode(QueryNodeType::AND) {} std::unique_ptr InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - uint32_t minimum_should_match) const override; + std::unique_ptr CreateSearch(CreateSearchParams params) const override; }; struct AndNotQueryNode final : public MultiQueryNode { AndNotQueryNode() : MultiQueryNode(QueryNodeType::AND_NOT) {} uint32_t LeafCount() const override { return children_[0]->LeafCount(); } std::unique_ptr InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - uint32_t minimum_should_match) const override; + std::unique_ptr CreateSearch(CreateSearchParams params) const override; }; struct OrQueryNode final : public MultiQueryNode { OrQueryNode() : MultiQueryNode(QueryNodeType::OR) {} std::unique_ptr InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, - const IndexReader &index_reader, - EarlyTermAlgo early_term_algo, - uint32_t minimum_should_match) const override; + std::unique_ptr CreateSearch(CreateSearchParams params) const override; }; // unimplemented diff --git a/src/storage/invertedindex/search/query_node_module.cppm b/src/storage/invertedindex/search/query_node_module.cppm index 941d7d02d1..6279820433 100644 --- a/src/storage/invertedindex/search/query_node_module.cppm +++ b/src/storage/invertedindex/search/query_node_module.cppm @@ -22,7 +22,7 @@ namespace infinity { export using infinity::QueryNodeType; export using infinity::QueryNodeTypeToString; - +export using infinity::CreateSearchParams; export using infinity::QueryNode; export using infinity::TermQueryNode; export using infinity::MultiQueryNode; diff --git a/src/storage/invertedindex/search/score_threshold_iterator.cpp b/src/storage/invertedindex/search/score_threshold_iterator.cpp index a17bc8a5aa..fac41134fc 100644 --- a/src/storage/invertedindex/search/score_threshold_iterator.cpp +++ b/src/storage/invertedindex/search/score_threshold_iterator.cpp @@ -32,7 +32,7 @@ bool ScoreThresholdIterator::Next(RowID doc_id) { } doc_id = query_iterator_->DocID(); // check score - if (BM25Score() >= score_threshold_) { + if (Score() >= score_threshold_) { doc_id_ = doc_id; return true; } diff --git a/src/storage/invertedindex/search/score_threshold_iterator.cppm b/src/storage/invertedindex/search/score_threshold_iterator.cppm index 5023b2564c..03ec9e0740 100644 --- a/src/storage/invertedindex/search/score_threshold_iterator.cppm +++ b/src/storage/invertedindex/search/score_threshold_iterator.cppm @@ -32,7 +32,7 @@ public: DocIteratorType GetType() const override { return DocIteratorType::kScoreThresholdIterator; } String Name() const override { return "ScoreThresholdIterator"; }; bool Next(RowID doc_id) override; - float BM25Score() override { return query_iterator_->BM25Score(); } + float Score() override { return query_iterator_->Score(); } void UpdateScoreThreshold(const float threshold) override { query_iterator_->UpdateScoreThreshold(threshold); } u32 MatchCount() const override { return query_iterator_->MatchCount(); } void PrintTree(std::ostream &os, const String &prefix, bool is_final) const override; diff --git a/src/storage/invertedindex/search/term_doc_iterator.cpp b/src/storage/invertedindex/search/term_doc_iterator.cpp index 161975c988..3d2203110b 100644 --- a/src/storage/invertedindex/search/term_doc_iterator.cpp +++ b/src/storage/invertedindex/search/term_doc_iterator.cpp @@ -24,8 +24,8 @@ import logger; namespace infinity { -TermDocIterator::TermDocIterator(UniquePtr &&iter, const u64 column_id, const float weight) - : column_id_(column_id), iter_(std::move(iter)), weight_(weight) { +TermDocIterator::TermDocIterator(UniquePtr &&iter, const u64 column_id, const float weight, const FulltextSimilarity ft_similarity) + : column_id_(column_id), iter_(std::move(iter)), weight_(weight), ft_similarity_(ft_similarity) { doc_freq_ = iter_->GetDocFreq(); term_freq_ = 0; estimate_iterate_cost_ = {0, doc_freq_}; diff --git a/src/storage/invertedindex/search/term_doc_iterator.cppm b/src/storage/invertedindex/search/term_doc_iterator.cppm index af8eb8d413..f03d97b536 100644 --- a/src/storage/invertedindex/search/term_doc_iterator.cppm +++ b/src/storage/invertedindex/search/term_doc_iterator.cppm @@ -26,11 +26,13 @@ import internal_types; import doc_iterator; import column_length_io; import third_party; +import parse_fulltext_options; namespace infinity { + export class TermDocIterator final : public DocIterator { public: - TermDocIterator(UniquePtr &&iter, u64 column_id, float weight); + TermDocIterator(UniquePtr &&iter, u64 column_id, float weight, FulltextSimilarity ft_similarity); ~TermDocIterator() override; @@ -61,7 +63,18 @@ public: bool Next(RowID doc_id) override; - float BM25Score() override; + float BM25Score(); + + float Score() override { + switch (ft_similarity_) { + case FulltextSimilarity::kBM25: { + return BM25Score(); + } + case FulltextSimilarity::kBoolean: { + return GetWeight(); + } + } + } void UpdateScoreThreshold(float threshold) override { if (threshold > threshold_) @@ -77,7 +90,6 @@ public: const String *column_name_ptr_ = nullptr; private: - Pair GetScoreData(); u32 doc_freq_ = 0; @@ -85,6 +97,7 @@ private: UniquePtr iter_; float weight_ = 1.0f; // changed in MultiplyWeight() u64 term_freq_; + const FulltextSimilarity ft_similarity_ = FulltextSimilarity::kBM25; // for BM25 Score float bm25_score_cache_ = 0.0f; @@ -109,4 +122,5 @@ private: u32 block_skip_cnt_ = 0; u32 block_skip_cnt_inner_ = 0; }; + } // namespace infinity diff --git a/src/unit_test/storage/invertedindex/search/query_builder.cpp b/src/unit_test/storage/invertedindex/search/query_builder.cpp index b6eea79f0c..d868120610 100644 --- a/src/unit_test/storage/invertedindex/search/query_builder.cpp +++ b/src/unit_test/storage/invertedindex/search/query_builder.cpp @@ -58,7 +58,7 @@ class MockVectorDocIterator : public DocIterator { return false; } - float BM25Score() override { return 0.1f; } + float Score() override { return 0.1f; } void UpdateScoreThreshold(float threshold) override {} @@ -94,7 +94,7 @@ struct MockQueryNode : public TermQueryNode { } void PushDownWeight(float factor) final { MultiplyWeight(factor); } - std::unique_ptr CreateSearch(const TableEntry *, const IndexReader &, EarlyTermAlgo, u32) const override { + std::unique_ptr CreateSearch(CreateSearchParams) const override { return MakeUnique(std::move(doc_ids_), term_, column_); } void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const final { @@ -200,11 +200,12 @@ TEST_F(QueryBuilderTest, test_and) { static_cast(and_root.get())->PrintTree(oss); LOG_INFO(oss.str()); // apply query builder - FullTextQueryContext context; + FullTextQueryContext context(FulltextSimilarity::kBM25, MinimumShouldMatchOption{}); + context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(and_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); + UniquePtr result_iter = builder.CreateSearch(context); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; @@ -270,11 +271,12 @@ TEST_F(QueryBuilderTest, test_or) { static_cast(or_root.get())->PrintTree(oss); LOG_INFO(oss.str()); // apply query builder - FullTextQueryContext context; + FullTextQueryContext context(FulltextSimilarity::kBM25, MinimumShouldMatchOption{}); + context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(or_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); + UniquePtr result_iter = builder.CreateSearch(context); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; @@ -346,11 +348,12 @@ TEST_F(QueryBuilderTest, test_and_not) { static_cast(and_not_root.get())->PrintTree(oss); LOG_INFO(oss.str()); // apply query builder - FullTextQueryContext context; + FullTextQueryContext context(FulltextSimilarity::kBM25, MinimumShouldMatchOption{}); + context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(and_not_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); + UniquePtr result_iter = builder.CreateSearch(context); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; @@ -428,11 +431,12 @@ TEST_F(QueryBuilderTest, test_and_not2) { static_cast(and_not_root.get())->PrintTree(oss); LOG_INFO(oss.str()); // apply query builder - FullTextQueryContext context; + FullTextQueryContext context(FulltextSimilarity::kBM25, MinimumShouldMatchOption{}); + context.early_term_algo_ = EarlyTermAlgo::kNaive; context.query_tree_ = std::move(and_not_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); + UniquePtr result_iter = builder.CreateSearch(context); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; diff --git a/src/unit_test/storage/invertedindex/search/query_match.cpp b/src/unit_test/storage/invertedindex/search/query_match.cpp index 6991afcf4a..ed06d2237d 100644 --- a/src/unit_test/storage/invertedindex/search/query_match.cpp +++ b/src/unit_test/storage/invertedindex/search/query_match.cpp @@ -338,16 +338,17 @@ void QueryMatchTest::QueryMatch(const String &db_name, Status status = Status::ParseMatchExprFailed(match_expr->fields_, match_expr->matching_text_); RecoverableError(status); } - FullTextQueryContext full_text_query_context; + FullTextQueryContext full_text_query_context(FulltextSimilarity::kBM25, MinimumShouldMatchOption{}); + full_text_query_context.early_term_algo_ = EarlyTermAlgo::kNaive; full_text_query_context.query_tree_ = std::move(query_tree); - UniquePtr doc_iterator = query_builder.CreateSearch(full_text_query_context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); + UniquePtr doc_iterator = query_builder.CreateSearch(full_text_query_context); RowID iter_row_id = doc_iterator.get() == nullptr ? INVALID_ROWID : (doc_iterator->Next(), doc_iterator->DocID()); if (iter_row_id == INVALID_ROWID) { fmt::print("iter_row_id is INVALID_ROWID\n"); } else { do { - auto score = doc_iterator->BM25Score(); + auto score = doc_iterator->Score(); fmt::print("iter_row_id = {}, score = {}\n", iter_row_id.ToUint64(), score); doc_iterator->Next(); iter_row_id = doc_iterator->DocID(); diff --git a/test/sql/dql/fulltext/fulltext_minimum_should_match.slt b/test/sql/dql/fulltext/fulltext_minimum_should_match.slt index 227aefeb66..5851504ad5 100644 --- a/test/sql/dql/fulltext/fulltext_minimum_should_match.slt +++ b/test/sql/dql/fulltext/fulltext_minimum_should_match.slt @@ -47,6 +47,12 @@ SELECT *, SCORE() FROM ft_minimum_should_match SEARCH MATCH TEXT ('doc', 'second 4 another second text xxx 1.079068 1 first text 0.366141 +query I +SELECT *, SCORE() FROM ft_minimum_should_match SEARCH MATCH TEXT ('doc', 'second text multiple', 'topn=10;similarity=boolean;threshold=1.5'); +---- +2 second text multiple 3.000000 +4 another second text xxx 2.000000 + query I SELECT * FROM ft_minimum_should_match WHERE filter_fulltext('doc', 'second text', 'threshold=0.3'); ----