From de0f9b352a72e57a0a8f56c17eb0ec98832352c8 Mon Sep 17 00:00:00 2001 From: yangzq50 Date: Fri, 13 Dec 2024 18:51:56 +0800 Subject: [PATCH] Support applying BlockMaxWand algorithm to PhraseDocIterator (#2369) ### What problem does this PR solve? Now BlockMaxWandIterator can accept both TermDocIterator and PhraseDocIterator as children Issue link:#1320 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring - [x] Performance Improvement --- example/fulltext_search.py | 4 +- example/fulltext_search_zh.py | 2 +- .../search/blockmax_leaf_iterator.cppm | 42 ++++++ .../search/blockmax_wand_iterator.cpp | 42 +++--- .../search/blockmax_wand_iterator.cppm | 6 +- .../search/phrase_doc_iterator.cpp | 70 +++++++++- .../search/phrase_doc_iterator.cppm | 23 +++- .../invertedindex/search/query_node.cpp | 128 ++++++++++-------- .../search/term_doc_iterator.cpp | 2 +- .../search/term_doc_iterator.cppm | 13 +- test/sql/dql/fulltext/fulltext.slt | 10 +- 11 files changed, 245 insertions(+), 97 deletions(-) create mode 100644 src/storage/invertedindex/search/blockmax_leaf_iterator.cppm diff --git a/example/fulltext_search.py b/example/fulltext_search.py index 75ad5c962d..ed9a4fb146 100644 --- a/example/fulltext_search.py +++ b/example/fulltext_search.py @@ -81,8 +81,8 @@ r"Bloom filter", # OR multiple terms r'"Bloom filter"', # phrase: adjacent multiple terms r"space efficient", # OR multiple terms - r"space\-efficient", # Escape reserved character '-', equivalent to: `space efficient` - r'"space\-efficient"', # phrase and escape reserved character, equivalent to: `"space efficient"` + r"space\:efficient", # Escape reserved character ':', equivalent to: `space efficient` + r'"space\:efficient"', # phrase and escape reserved character, equivalent to: `"space efficient"` r'"harmful chemical"~10', # sloppy phrase, refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query-phrase.html ] for question in questions: diff --git a/example/fulltext_search_zh.py b/example/fulltext_search_zh.py index d658058c77..42a6516a53 100644 --- a/example/fulltext_search_zh.py +++ b/example/fulltext_search_zh.py @@ -102,7 +102,7 @@ r"羽毛球", # single term r'"羽毛球锦标赛"', # phrase: adjacent multiple terms r"2018年世界羽毛球锦标赛在哪个城市举办?", # OR multiple terms - r"high\-tech", # Escape reserved character '-' + r"high\:tech", # Escape reserved character ':' r'"high tech"', # phrase: adjacent multiple terms r'"high-tech"', # phrase: adjacent multiple terms r"graphics card", # OR multiple terms diff --git a/src/storage/invertedindex/search/blockmax_leaf_iterator.cppm b/src/storage/invertedindex/search/blockmax_leaf_iterator.cppm new file mode 100644 index 0000000000..3c76f8fc16 --- /dev/null +++ b/src/storage/invertedindex/search/blockmax_leaf_iterator.cppm @@ -0,0 +1,42 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +export module blockmax_leaf_iterator; + +import stl; +import internal_types; +import doc_iterator; + +namespace infinity { + +export class BlockMaxLeafIterator : public DocIterator { +public: + virtual RowID BlockMinPossibleDocID() const = 0; + + virtual RowID BlockLastDocID() const = 0; + + virtual float BlockMaxBM25Score() = 0; + + // Move block cursor to ensure its last_doc_id is no less than given doc_id. + // Returns false and update doc_id_ to INVALID_ROWID if the iterator is exhausted. + // Note that this routine decode skip_list only, and doesn't update doc_id_ when returns true. + // Caller may invoke BlockMaxBM25Score() after this routine. + virtual bool NextShallow(RowID doc_id) = 0; + + virtual float BM25Score() = 0; +}; + +} // namespace infinity diff --git a/src/storage/invertedindex/search/blockmax_wand_iterator.cpp b/src/storage/invertedindex/search/blockmax_wand_iterator.cpp index b186c563bd..07b50979c1 100644 --- a/src/storage/invertedindex/search/blockmax_wand_iterator.cpp +++ b/src/storage/invertedindex/search/blockmax_wand_iterator.cpp @@ -20,7 +20,7 @@ module blockmax_wand_iterator; import stl; import third_party; import index_defines; -import term_doc_iterator; +import blockmax_leaf_iterator; import multi_doc_iterator; import internal_types; import logger; @@ -29,18 +29,24 @@ import infinity_exception; namespace infinity { BlockMaxWandIterator::~BlockMaxWandIterator() { - String msg = "BlockMaxWandIterator pivot_history: "; - SizeT num_history = pivot_history_.size(); - for (SizeT i=0; i(p); - u64 row_id = std::get<1>(p); - float score = std::get<2>(p); - //oss << " (" << pivot << ", " << row_id << ", " << score << ")"; - msg += fmt::format(" ({}, {}, {:6f})", pivot, row_id, score); + if (SHOULD_LOG_TRACE()) { + String msg = "BlockMaxWandIterator pivot_history: "; + SizeT num_history = pivot_history_.size(); + for (SizeT i = 0; i < num_history; i++) { + auto &p = pivot_history_[i]; + u32 pivot = std::get<0>(p); + u64 row_id = std::get<1>(p); + float score = std::get<2>(p); + //oss << " (" << pivot << ", " << row_id << ", " << score << ")"; + msg += fmt::format(" ({}, {}, {:6f})", pivot, row_id, score); + } + msg += fmt::format("\nnext_sort_cnt_ {}, next_it0_docid_mismatch_cnt_ {}, next_sum_score_low_cnt_ {}, next_sum_score_bm_low_cnt_ {}", + next_sort_cnt_, + next_it0_docid_mismatch_cnt_, + next_sum_score_low_cnt_, + next_sum_score_bm_low_cnt_); + LOG_TRACE(msg); } - msg += fmt::format("\nnext_sort_cnt_ {}, next_it0_docid_mismatch_cnt_ {}, next_sum_score_low_cnt_ {}, next_sum_score_bm_low_cnt_ {}", next_sort_cnt_, next_it0_docid_mismatch_cnt_, next_sum_score_low_cnt_, next_sum_score_bm_low_cnt_); - LOG_TRACE(msg); } BlockMaxWandIterator::BlockMaxWandIterator(Vector> &&iterators) @@ -49,9 +55,9 @@ BlockMaxWandIterator::BlockMaxWandIterator(Vector> &&iter estimate_iterate_cost_ = {}; SizeT num_iterators = children_.size(); for (SizeT i = 0; i < num_iterators; i++){ - TermDocIterator *tdi = dynamic_cast(children_[i].get()); + BlockMaxLeafIterator *tdi = dynamic_cast(children_[i].get()); if (tdi == nullptr) { - UnrecoverableError("BMW only supports TermDocIterator"); + UnrecoverableError("BMW only supports BlockMaxLeafIterator"); } bm25_score_upper_bound_ += tdi->BM25ScoreUpperBound(); estimate_iterate_cost_ += tdi->GetEstimateIterateCost(); @@ -101,10 +107,10 @@ bool BlockMaxWandIterator::Next(RowID doc_id){ }); // remove exhausted lists for (int i = int(num_iterators) - 1; i >= 0 && sorted_iterators_[i]->DocID() == INVALID_ROWID; i--) { - if (SHOULD_LOG_DEBUG()) { + if (SHOULD_LOG_TRACE()) { OStringStream oss; sorted_iterators_[i]->PrintTree(oss, "Exhaused: ", true); - LOG_DEBUG(oss.str()); + LOG_TRACE(oss.str()); } bm25_score_upper_bound_ -= sorted_iterators_[i]->BM25ScoreUpperBound(); sorted_iterators_.pop_back(); @@ -142,10 +148,10 @@ bool BlockMaxWandIterator::Next(RowID doc_id){ if (ok) [[likely]] { sum_score_bm += sorted_iterators_[i]->BlockMaxBM25Score(); } else { - if (SHOULD_LOG_DEBUG()) { + if (SHOULD_LOG_TRACE()) { OStringStream oss; sorted_iterators_[i]->PrintTree(oss, "Exhausted: ", true); - LOG_DEBUG(oss.str()); + LOG_TRACE(oss.str()); } sorted_iterators_.erase(sorted_iterators_.begin() + i); num_iterators = sorted_iterators_.size(); diff --git a/src/storage/invertedindex/search/blockmax_wand_iterator.cppm b/src/storage/invertedindex/search/blockmax_wand_iterator.cppm index 9678fec919..bd6e4a2895 100644 --- a/src/storage/invertedindex/search/blockmax_wand_iterator.cppm +++ b/src/storage/invertedindex/search/blockmax_wand_iterator.cppm @@ -18,7 +18,7 @@ export module blockmax_wand_iterator; import stl; import index_defines; import doc_iterator; -import term_doc_iterator; +import blockmax_leaf_iterator; import multi_doc_iterator; import internal_types; @@ -50,8 +50,8 @@ private: RowID common_block_min_possible_doc_id_{}; // not always exist RowID common_block_last_doc_id_{}; float common_block_max_bm25_score_{}; - Vector sorted_iterators_; // sort by DocID(), in ascending order - Vector backup_iterators_; + Vector sorted_iterators_; // sort by DocID(), in ascending order + Vector backup_iterators_; SizeT pivot_; // bm25 score cache bool bm25_score_cached_ = false; diff --git a/src/storage/invertedindex/search/phrase_doc_iterator.cpp b/src/storage/invertedindex/search/phrase_doc_iterator.cpp index dc131afc18..2651ade930 100644 --- a/src/storage/invertedindex/search/phrase_doc_iterator.cpp +++ b/src/storage/invertedindex/search/phrase_doc_iterator.cpp @@ -2,6 +2,7 @@ module; #include #include +#include module phrase_doc_iterator; @@ -31,6 +32,8 @@ PhraseDocIterator::PhraseDocIterator(Vector> &&iters, estimate_doc_freq_ = std::min(estimate_doc_freq_, pos_iters_[i]->GetDocFreq()); } estimate_iterate_cost_ = {1, estimate_doc_freq_}; + block_max_bm25_score_cache_part_info_end_ids_.resize(pos_iters_.size(), INVALID_ROWID); + block_max_bm25_score_cache_part_info_vals_.resize(pos_iters_.size()); } void PhraseDocIterator::InitBM25Info(UniquePtr &&column_length_reader) { @@ -41,9 +44,12 @@ void PhraseDocIterator::InitBM25Info(UniquePtr &&col column_length_reader_ = std::move(column_length_reader); u64 total_df = column_length_reader_->GetTotalDF(); float avg_column_len = column_length_reader_->GetAvgColumnLength(); - float smooth_idf = std::log(1.0F + (total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F)); + float smooth_idf = std::log1p((total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F)); bm25_common_score_ = weight_ * smooth_idf * (k1 + 1.0F); bm25_score_upper_bound_ = bm25_common_score_ / (1.0F + k1 * b / avg_column_len); + f1 = k1 * (1.0F - b); + f2 = k1 * b / avg_column_len; + f3 = f2 * std::numeric_limits::max(); if (SHOULD_LOG_TRACE()) { OStringStream oss; oss << "TermDocIterator: "; @@ -80,6 +86,7 @@ bool PhraseDocIterator::Next(const RowID doc_id) { bool found = GetPhraseMatchData(); if (found && (threshold_ <= 0.0f || BM25Score() > threshold_)) { doc_id_ = target_doc_id; + UpdateBlockRangeDocID(); return true; } ++target_doc_id; @@ -87,6 +94,66 @@ bool PhraseDocIterator::Next(const RowID doc_id) { } } +void PhraseDocIterator::UpdateBlockRangeDocID() { + RowID min_doc_id = 0; + RowID max_doc_id = INVALID_ROWID; + for (const auto &it : pos_iters_) { + min_doc_id = std::max(min_doc_id, it->BlockLowestPossibleDocID()); + max_doc_id = std::min(max_doc_id, it->BlockLastDocID()); + } + block_min_possible_doc_id_ = min_doc_id; + block_last_doc_id_ = max_doc_id; +} + +float PhraseDocIterator::BlockMaxBM25Score() { + if (const auto last_doc_id = BlockLastDocID(); last_doc_id != block_max_bm25_score_cache_end_id_) { + block_max_bm25_score_cache_end_id_ = last_doc_id; + // bm25_common_score_ / (1.0F + k1 * ((1.0F - b) / block_max_tf + b / block_max_percentage / avg_column_len)); + // block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + f1 / block_max_tf + f3 / block_max_percentage_u16); + float div_add_min = std::numeric_limits::max(); + for (SizeT i = 0; i < pos_iters_.size(); ++i) { + const auto *iter = pos_iters_[i].get(); + float current_div_add_min = {}; + if (const auto iter_block_last_doc_id = iter->BlockLastDocID(); + iter_block_last_doc_id == block_max_bm25_score_cache_part_info_end_ids_[i]) { + current_div_add_min = block_max_bm25_score_cache_part_info_vals_[i]; + } else { + block_max_bm25_score_cache_part_info_end_ids_[i] = iter_block_last_doc_id; + const auto [block_max_tf, block_max_percentage_u16] = iter->GetBlockMaxInfo(); + current_div_add_min = f1 / block_max_tf + f3 / block_max_percentage_u16; + block_max_bm25_score_cache_part_info_vals_[i] = current_div_add_min; + } + div_add_min = std::min(div_add_min, current_div_add_min); + } + block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + div_add_min); + } + return block_max_bm25_score_cache_; +} + +// Move block cursor to ensure its last_doc_id is no less than given doc_id. +// Returns false and update doc_id_ to INVALID_ROWID if the iterator is exhausted. +// Note that this routine decode skip_list only, and doesn't update doc_id_ when returns true. +// Caller may invoke BlockMaxBM25Score() after this routine. +bool PhraseDocIterator::NextShallow(RowID doc_id) { + if (threshold_ > BM25ScoreUpperBound()) [[unlikely]] { + doc_id_ = INVALID_ROWID; + return false; + } + while (true) { + for (const auto &iter : pos_iters_) { + if (!iter->SkipTo(doc_id)) { + doc_id_ = INVALID_ROWID; + return false; + } + } + UpdateBlockRangeDocID(); + if (threshold_ <= 0.0f || BlockMaxBM25Score() > threshold_) { + return true; + } + doc_id = BlockLastDocID() + 1; + } +} + float PhraseDocIterator::BM25Score() { if (doc_id_ == bm25_score_cache_docid_) [[unlikely]] { return bm25_score_cache_; @@ -112,6 +179,7 @@ void PhraseDocIterator::PrintTree(std::ostream &os, const String &prefix, bool i } os << ")"; os << " (doc_freq: " << GetDocFreq() << ")"; + os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")"; os << '\n'; } diff --git a/src/storage/invertedindex/search/phrase_doc_iterator.cppm b/src/storage/invertedindex/search/phrase_doc_iterator.cppm index 35f80546c7..75d8092311 100644 --- a/src/storage/invertedindex/search/phrase_doc_iterator.cppm +++ b/src/storage/invertedindex/search/phrase_doc_iterator.cppm @@ -10,10 +10,11 @@ import posting_iterator; import index_defines; import column_length_io; import parse_fulltext_options; +import blockmax_leaf_iterator; namespace infinity { -export class PhraseDocIterator final : public DocIterator { +export class PhraseDocIterator final : public BlockMaxLeafIterator { public: PhraseDocIterator(Vector> &&iters, float weight, u32 slop, FulltextSimilarity ft_similarity); @@ -32,7 +33,21 @@ public: bool Next(RowID doc_id) override; - float BM25Score(); + RowID BlockMinPossibleDocID() const override { return block_min_possible_doc_id_; } + + RowID BlockLastDocID() const override { return block_last_doc_id_; } + + void UpdateBlockRangeDocID(); + + float BlockMaxBM25Score() override; + + // Move block cursor to ensure its last_doc_id is no less than given doc_id. + // Returns false and update doc_id_ to INVALID_ROWID if the iterator is exhausted. + // Note that this routine decode skip_list only, and doesn't update doc_id_ when returns true. + // Caller may invoke BlockMaxBM25Score() after this routine. + bool NextShallow(RowID doc_id) override; + + float BM25Score() override; float Score() override { switch (ft_similarity_) { @@ -86,6 +101,10 @@ private: UniquePtr column_length_reader_ = nullptr; float block_max_bm25_score_cache_ = 0.0f; RowID block_max_bm25_score_cache_end_id_ = INVALID_ROWID; + Vector block_max_bm25_score_cache_part_info_end_ids_; + Vector block_max_bm25_score_cache_part_info_vals_; + RowID block_min_possible_doc_id_ = INVALID_ROWID; + RowID block_last_doc_id_ = INVALID_ROWID; float tf_ = 0.0f; // current doc_id_'s tf u32 estimate_doc_freq_{0}; // estimated at the beginning diff --git a/src/storage/invertedindex/search/query_node.cpp b/src/storage/invertedindex/search/query_node.cpp index c7f39fc4d7..d90774b09d 100644 --- a/src/storage/invertedindex/search/query_node.cpp +++ b/src/storage/invertedindex/search/query_node.cpp @@ -508,7 +508,6 @@ std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams Vector> sub_doc_iters; Vector> keyword_iters; sub_doc_iters.reserve(children_.size()); - bool all_are_term = true; // describe sub_doc_iters bool all_are_term_or_phrase = true; // describe sub_doc_iters const QueryNode *only_child = nullptr; const auto next_params = params.RemoveMSM(); @@ -519,11 +518,8 @@ std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams keyword_iters.emplace_back(std::move(iter)); } else { sub_doc_iters.emplace_back(std::move(iter)); - if (child_type != QueryNodeType::TERM) { - all_are_term = false; - if (child_type != QueryNodeType::PHRASE) { - all_are_term_or_phrase = false; - } + if (child_type != QueryNodeType::TERM && child_type != QueryNodeType::PHRASE) { + all_are_term_or_phrase = false; } } } @@ -531,11 +527,10 @@ std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams if (sub_doc_iters.size() < 2) { // 0 or 1 // no need for WAND - all_are_term = false; all_are_term_or_phrase = false; } - const u32 msm_bar = keyword_iters.empty() ? 1u : 0u; - auto GetIterResultT = [&]() -> std::unique_ptr { + auto GetIterResultT = [¶ms, &sub_doc_iters, &keyword_iters]() -> std::unique_ptr { + const u32 msm_bar = keyword_iters.empty() ? 1u : 0u; if (params.minimum_should_match > sub_doc_iters.size()) { return nullptr; } else if (params.minimum_should_match <= msm_bar) { @@ -571,7 +566,7 @@ std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams } } }; - auto term_num_threshold = [](const u32 topn) -> u32 { + [[maybe_unused]] auto term_num_threshold = [](const u32 topn) -> u32 { if (topn < 5u) { return std::numeric_limits::max(); } @@ -580,31 +575,46 @@ std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams } return 50u / std::log10f(topn); }; + auto term_children_need_batch = [&sub_doc_iters]() -> bool { + u64 total_df = 0u; + u64 df_sum = 0u; + for (const auto &iter : sub_doc_iters) { + if (iter->GetType() == DocIteratorType::kTermDocIterator) { + const auto tdi = static_cast(iter.get()); + total_df = tdi->GetTotalDF(); + df_sum += tdi->GetDocFreq(); + } + } + return df_sum && (df_sum * 5ull >= total_df); + }; if (sub_doc_iters.empty() && keyword_iters.empty()) { return nullptr; } if (sub_doc_iters.size() + keyword_iters.size() == 1) { return only_child->CreateSearch(params, is_top_level); } - if (is_top_level && all_are_term && params.ft_similarity == FulltextSimilarity::kBM25) { + if (is_top_level && all_are_term_or_phrase && params.ft_similarity == FulltextSimilarity::kBM25) { auto choose_algo = EarlyTermAlgo::kNaive; switch (params.early_term_algo) { case EarlyTermAlgo::kAuto: { - if (params.topn > 0u && sub_doc_iters.size() <= term_num_threshold(params.topn)) { + if (params.topn) { + // always prefer BMW choose_algo = EarlyTermAlgo::kBMW; + } else if (term_children_need_batch()) { + // topn == 0, case of filter + choose_algo = EarlyTermAlgo::kBatch; } else { - // check df - const auto total_df = static_cast(sub_doc_iters.front().get())->GetTotalDF(); - u64 df_sum = 0u; - for (const auto &iter : sub_doc_iters) { - df_sum += static_cast(iter.get())->GetDocFreq(); - } - if (df_sum * 5ull < total_df) { - choose_algo = EarlyTermAlgo::kBMW; - } else { - choose_algo = EarlyTermAlgo::kBatch; - } + choose_algo = EarlyTermAlgo::kNaive; + } + /* TODO: now always use BMW + if ((params.topn == 0u || sub_doc_iters.size() > term_num_threshold(params.topn)) && term_children_need_batch()) { + choose_algo = EarlyTermAlgo::kBatch; + } else if (params.topn == 0u) { + choose_algo = EarlyTermAlgo::kNaive; + } else { + choose_algo = EarlyTermAlgo::kBMW; } + */ break; } case EarlyTermAlgo::kBMW: @@ -618,48 +628,50 @@ std::unique_ptr OrQueryNode::CreateSearch(const CreateSearchParams break; } } - if (choose_algo == EarlyTermAlgo::kBMW) { - return GetIterResultT.template operator()(); - } else if (choose_algo == EarlyTermAlgo::kBatch) { - return GetIterResultT.template operator()(); - } else if (choose_algo == EarlyTermAlgo::kNaive) { - return GetIterResultT.template operator()(); + switch (choose_algo) { + case EarlyTermAlgo::kBMW: { + return GetIterResultT.template operator()(); + } + case EarlyTermAlgo::kNaive: { + return GetIterResultT.template operator()(); + } + case EarlyTermAlgo::kBatch: { + assert(params.early_term_algo == EarlyTermAlgo::kAuto || params.early_term_algo == EarlyTermAlgo::kBatch); + // go to next "if" block + break; + } + default: { + UnrecoverableError(fmt::format("{}: Unexpected case!", __func__)); + return nullptr; + } } - UnrecoverableError("Unreachable code"); - return nullptr; } if ((params.early_term_algo == EarlyTermAlgo::kAuto || params.early_term_algo == EarlyTermAlgo::kBatch) && - params.ft_similarity == FulltextSimilarity::kBM25) { - // try to apply batch when possible - // collect all term children info - u64 total_df = 0u; - u64 df_sum = 0u; - for (const auto &iter : sub_doc_iters) { + params.ft_similarity == FulltextSimilarity::kBM25 && term_children_need_batch()) { + // term_iters will be non-empty + Vector> term_iters; + Vector> not_term_iters = std::move(keyword_iters); + for (auto &iter : sub_doc_iters) { if (iter->GetType() == DocIteratorType::kTermDocIterator) { - const auto tdi = static_cast(iter.get()); - total_df = tdi->GetTotalDF(); - df_sum += tdi->GetDocFreq(); - } - } - if (df_sum && (df_sum * 5ull >= total_df)) { - // must have child other than term - Vector> term_iters; - Vector> not_term_iters = std::move(keyword_iters); - for (auto &iter : sub_doc_iters) { - if (iter->GetType() == DocIteratorType::kTermDocIterator) { - term_iters.emplace_back(std::move(iter)); - } else { - not_term_iters.emplace_back(std::move(iter)); - } - } - auto batch_or_iter = MakeUnique(std::move(term_iters)); - not_term_iters.emplace_back(std::move(batch_or_iter)); - if (params.minimum_should_match <= 0) { - return MakeUnique(std::move(not_term_iters)); + term_iters.emplace_back(std::move(iter)); } else { - return MakeUnique>(std::move(not_term_iters), params.minimum_should_match); + not_term_iters.emplace_back(std::move(iter)); } } + if (not_term_iters.empty()) { + assert(all_are_term_or_phrase); + sub_doc_iters = std::move(term_iters); + keyword_iters.clear(); + return GetIterResultT.template operator()(); + } + auto batch_or_iter = MakeUnique(std::move(term_iters)); + not_term_iters.emplace_back(std::move(batch_or_iter)); + // now at least 2 children in not_term_iters + if (params.minimum_should_match <= 0) { + return MakeUnique(std::move(not_term_iters)); + } else { + return MakeUnique>(std::move(not_term_iters), params.minimum_should_match); + } } if (all_are_term_or_phrase) { return GetIterResultT.template operator()(); diff --git a/src/storage/invertedindex/search/term_doc_iterator.cpp b/src/storage/invertedindex/search/term_doc_iterator.cpp index 915734af1a..13fcebcaf2 100644 --- a/src/storage/invertedindex/search/term_doc_iterator.cpp +++ b/src/storage/invertedindex/search/term_doc_iterator.cpp @@ -53,7 +53,7 @@ void TermDocIterator::InitBM25Info(UniquePtr &&colum column_length_reader_ = std::move(column_length_reader); avg_column_len_ = column_length_reader_->GetAvgColumnLength(); total_df_ = column_length_reader_->GetTotalDF(); - const float smooth_idf = std::log(1.0F + (column_length_reader_->GetTotalDF() - doc_freq_ + 0.5F) / (doc_freq_ + 0.5F)); + const float smooth_idf = std::log1p((column_length_reader_->GetTotalDF() - doc_freq_ + 0.5F) / (doc_freq_ + 0.5F)); bm25_common_score_ = weight_ * smooth_idf * (k1 + 1.0F); bm25_score_upper_bound_ = bm25_common_score_ / (1.0F + k1 * b / avg_column_len_); f1 = k1 * (1.0F - b); diff --git a/src/storage/invertedindex/search/term_doc_iterator.cppm b/src/storage/invertedindex/search/term_doc_iterator.cppm index 18bf99d6fe..0d21ac458d 100644 --- a/src/storage/invertedindex/search/term_doc_iterator.cppm +++ b/src/storage/invertedindex/search/term_doc_iterator.cppm @@ -27,10 +27,11 @@ import doc_iterator; import column_length_io; import third_party; import parse_fulltext_options; +import blockmax_leaf_iterator; namespace infinity { -export class TermDocIterator final : public DocIterator { +export class TermDocIterator final : public BlockMaxLeafIterator { public: TermDocIterator(UniquePtr &&iter, u64 column_id, float weight, FulltextSimilarity ft_similarity); @@ -48,15 +49,15 @@ public: void InitBM25Info(UniquePtr &&column_length_reader); - RowID BlockMinPossibleDocID() const { return iter_->BlockLowestPossibleDocID(); } - RowID BlockLastDocID() const { return iter_->BlockLastDocID(); } - float BlockMaxBM25Score(); + RowID BlockMinPossibleDocID() const override { return iter_->BlockLowestPossibleDocID(); } + RowID BlockLastDocID() const override { return iter_->BlockLastDocID(); } + float BlockMaxBM25Score() override; // Move block cursor to ensure its last_doc_id is no less than given doc_id. // Returns false and update doc_id_ to INVALID_ROWID if the iterator is exhausted. // Note that this routine decode skip_list only, and doesn't update doc_id_ when returns true. // Caller may invoke BlockMaxBM25Score() after this routine. - bool NextShallow(RowID doc_id); + bool NextShallow(RowID doc_id) override; // Overriden methods DocIteratorType GetType() const override { return DocIteratorType::kTermDocIterator; } @@ -65,7 +66,7 @@ public: bool Next(RowID doc_id) override; - float BM25Score(); + float BM25Score() override; float Score() override { switch (ft_similarity_) { diff --git a/test/sql/dql/fulltext/fulltext.slt b/test/sql/dql/fulltext/fulltext.slt index 63b321a15f..32fba39246 100644 --- a/test/sql/dql/fulltext/fulltext.slt +++ b/test/sql/dql/fulltext/fulltext.slt @@ -36,26 +36,26 @@ Anarchism 30-APR-2012 03:25:17.000 0 22.299635 query TTIR rowsort SELECT doctitle, docdate, ROW_ID(), SCORE() FROM sqllogic_test_enwiki SEARCH MATCH TEXT ('body^5', '"social customs"', 'topn=3;block_max=compare') USING INDEXES ('ft_index'); ---- -Anarchism 30-APR-2012 03:25:17.000 6 46.196758 +Anarchism 30-APR-2012 03:25:17.000 6 20.753590 # only phrase query TTIR rowsort SELECT doctitle, docdate, ROW_ID(), SCORE() FROM sqllogic_test_enwiki SEARCH MATCH TEXT ('body^5', '"social customs"', 'topn=3;block_max=compare'); ---- -Anarchism 30-APR-2012 03:25:17.000 6 46.196758 +Anarchism 30-APR-2012 03:25:17.000 6 20.753590 # phrase and term query TTIR rowsort SELECT doctitle, docdate, ROW_ID(), SCORE() FROM sqllogic_test_enwiki SEARCH MATCH TEXT ('body^5', '"social customs" harmful', 'topn=3'); ---- Anarchism 30-APR-2012 03:25:17.000 0 22.299635 -Anarchism 30-APR-2012 03:25:17.000 6 46.196758 +Anarchism 30-APR-2012 03:25:17.000 6 20.753590 # phrase and term query TTIR rowsort -SELECT doctitle, docdate, ROW_ID(), SCORE() FROM sqllogic_test_enwiki SEARCH MATCH TEXT ('body^5', '"social customs" harmful', 'topn=3;threshold=40'); +SELECT doctitle, docdate, ROW_ID(), SCORE() FROM sqllogic_test_enwiki SEARCH MATCH TEXT ('body^5', '"social customs" harmful', 'topn=3;threshold=21.5'); ---- -Anarchism 30-APR-2012 03:25:17.000 6 46.196758 +Anarchism 30-APR-2012 03:25:17.000 0 22.299635 # copy data from csv file query I