Skip to content

Commit

Permalink
Support applying BlockMaxWand algorithm to PhraseDocIterator (#2369)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Now BlockMaxWandIterator can accept both TermDocIterator and
PhraseDocIterator as children

Issue link: #1320

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
- [x] Performance Improvement
  • Loading branch information
yangzq50 authored Dec 13, 2024
1 parent 861c142 commit de0f9b3
Show file tree
Hide file tree
Showing 11 changed files with 245 additions and 97 deletions.
4 changes: 2 additions & 2 deletions example/fulltext_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
r"Bloom filter", # OR multiple terms
r'"Bloom filter"', # phrase: adjacent multiple terms
r"space efficient", # OR multiple terms
r"space\-efficient", # Escape reserved character '-', equivalent to: `space efficient`
r'"space\-efficient"', # phrase and escape reserved character, equivalent to: `"space efficient"`
r"space\:efficient", # Escape reserved character ':', equivalent to: `space efficient`
r'"space\:efficient"', # phrase and escape reserved character, equivalent to: `"space efficient"`
r'"harmful chemical"~10', # sloppy phrase, refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query-phrase.html
]
for question in questions:
Expand Down
2 changes: 1 addition & 1 deletion example/fulltext_search_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
r"羽毛球", # single term
r'"羽毛球锦标赛"', # phrase: adjacent multiple terms
r"2018年世界羽毛球锦标赛在哪个城市举办?", # OR multiple terms
r"high\-tech", # Escape reserved character '-'
r"high\:tech", # Escape reserved character ':'
r'"high tech"', # phrase: adjacent multiple terms
r'"high-tech"', # phrase: adjacent multiple terms
r"graphics card", # OR multiple terms
Expand Down
42 changes: 42 additions & 0 deletions src/storage/invertedindex/search/blockmax_leaf_iterator.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module blockmax_leaf_iterator;

import stl;
import internal_types;
import doc_iterator;

namespace infinity {

// Abstract interface for DocIterator implementations that expose per-block
// metadata (doc id range and a BM25 score bound for the current block) in
// addition to the regular DocIterator operations.
// NOTE(review): intended to be the child type consumed by block-max WAND style
// parent iterators — confirm against the callers.
export class BlockMaxLeafIterator : public DocIterator {
public:
// Returns the lowest doc id that the block under the block cursor may contain.
// NOTE(review): presumably a lower bound rather than an existing doc id — confirm with implementations.
virtual RowID BlockMinPossibleDocID() const = 0;

// Returns the last doc id of the block under the block cursor.
virtual RowID BlockLastDocID() const = 0;

// Returns the BM25 score bound for the block under the block cursor.
// Non-const, so implementations are free to cache the computed value.
virtual float BlockMaxBM25Score() = 0;

// Moves the block cursor so that the last doc id of the current block is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine only decodes the skip list; it does not update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
virtual bool NextShallow(RowID doc_id) = 0;

// Returns the BM25 score.
// NOTE(review): presumably for the document the iterator is currently positioned on (doc_id_) — confirm.
virtual float BM25Score() = 0;
};

} // namespace infinity
42 changes: 24 additions & 18 deletions src/storage/invertedindex/search/blockmax_wand_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module blockmax_wand_iterator;
import stl;
import third_party;
import index_defines;
import term_doc_iterator;
import blockmax_leaf_iterator;
import multi_doc_iterator;
import internal_types;
import logger;
Expand All @@ -29,18 +29,24 @@ import infinity_exception;
namespace infinity {

BlockMaxWandIterator::~BlockMaxWandIterator() {
String msg = "BlockMaxWandIterator pivot_history: ";
SizeT num_history = pivot_history_.size();
for (SizeT i=0; i<num_history; i++) {
auto &p = pivot_history_[i];
u32 pivot = std::get<0>(p);
u64 row_id = std::get<1>(p);
float score = std::get<2>(p);
//oss << " (" << pivot << ", " << row_id << ", " << score << ")";
msg += fmt::format(" ({}, {}, {:6f})", pivot, row_id, score);
if (SHOULD_LOG_TRACE()) {
String msg = "BlockMaxWandIterator pivot_history: ";
SizeT num_history = pivot_history_.size();
for (SizeT i = 0; i < num_history; i++) {
auto &p = pivot_history_[i];
u32 pivot = std::get<0>(p);
u64 row_id = std::get<1>(p);
float score = std::get<2>(p);
//oss << " (" << pivot << ", " << row_id << ", " << score << ")";
msg += fmt::format(" ({}, {}, {:6f})", pivot, row_id, score);
}
msg += fmt::format("\nnext_sort_cnt_ {}, next_it0_docid_mismatch_cnt_ {}, next_sum_score_low_cnt_ {}, next_sum_score_bm_low_cnt_ {}",
next_sort_cnt_,
next_it0_docid_mismatch_cnt_,
next_sum_score_low_cnt_,
next_sum_score_bm_low_cnt_);
LOG_TRACE(msg);
}
msg += fmt::format("\nnext_sort_cnt_ {}, next_it0_docid_mismatch_cnt_ {}, next_sum_score_low_cnt_ {}, next_sum_score_bm_low_cnt_ {}", next_sort_cnt_, next_it0_docid_mismatch_cnt_, next_sum_score_low_cnt_, next_sum_score_bm_low_cnt_);
LOG_TRACE(msg);
}

BlockMaxWandIterator::BlockMaxWandIterator(Vector<UniquePtr<DocIterator>> &&iterators)
Expand All @@ -49,9 +55,9 @@ BlockMaxWandIterator::BlockMaxWandIterator(Vector<UniquePtr<DocIterator>> &&iter
estimate_iterate_cost_ = {};
SizeT num_iterators = children_.size();
for (SizeT i = 0; i < num_iterators; i++){
TermDocIterator *tdi = dynamic_cast<TermDocIterator *>(children_[i].get());
BlockMaxLeafIterator *tdi = dynamic_cast<BlockMaxLeafIterator *>(children_[i].get());
if (tdi == nullptr) {
UnrecoverableError("BMW only supports TermDocIterator");
UnrecoverableError("BMW only supports BlockMaxLeafIterator");
}
bm25_score_upper_bound_ += tdi->BM25ScoreUpperBound();
estimate_iterate_cost_ += tdi->GetEstimateIterateCost();
Expand Down Expand Up @@ -101,10 +107,10 @@ bool BlockMaxWandIterator::Next(RowID doc_id){
});
// remove exhausted lists
for (int i = int(num_iterators) - 1; i >= 0 && sorted_iterators_[i]->DocID() == INVALID_ROWID; i--) {
if (SHOULD_LOG_DEBUG()) {
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
sorted_iterators_[i]->PrintTree(oss, "Exhaused: ", true);
LOG_DEBUG(oss.str());
LOG_TRACE(oss.str());
}
bm25_score_upper_bound_ -= sorted_iterators_[i]->BM25ScoreUpperBound();
sorted_iterators_.pop_back();
Expand Down Expand Up @@ -142,10 +148,10 @@ bool BlockMaxWandIterator::Next(RowID doc_id){
if (ok) [[likely]] {
sum_score_bm += sorted_iterators_[i]->BlockMaxBM25Score();
} else {
if (SHOULD_LOG_DEBUG()) {
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
sorted_iterators_[i]->PrintTree(oss, "Exhausted: ", true);
LOG_DEBUG(oss.str());
LOG_TRACE(oss.str());
}
sorted_iterators_.erase(sorted_iterators_.begin() + i);
num_iterators = sorted_iterators_.size();
Expand Down
6 changes: 3 additions & 3 deletions src/storage/invertedindex/search/blockmax_wand_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export module blockmax_wand_iterator;
import stl;
import index_defines;
import doc_iterator;
import term_doc_iterator;
import blockmax_leaf_iterator;
import multi_doc_iterator;
import internal_types;

Expand Down Expand Up @@ -50,8 +50,8 @@ private:
RowID common_block_min_possible_doc_id_{}; // not always exist
RowID common_block_last_doc_id_{};
float common_block_max_bm25_score_{};
Vector<TermDocIterator *> sorted_iterators_; // sort by DocID(), in ascending order
Vector<TermDocIterator *> backup_iterators_;
Vector<BlockMaxLeafIterator *> sorted_iterators_; // sort by DocID(), in ascending order
Vector<BlockMaxLeafIterator *> backup_iterators_;
SizeT pivot_;
// bm25 score cache
bool bm25_score_cached_ = false;
Expand Down
70 changes: 69 additions & 1 deletion src/storage/invertedindex/search/phrase_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module;

#include <cassert>
#include <iostream>
#include <vector>

module phrase_doc_iterator;

Expand Down Expand Up @@ -31,6 +32,8 @@ PhraseDocIterator::PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters,
estimate_doc_freq_ = std::min(estimate_doc_freq_, pos_iters_[i]->GetDocFreq());
}
estimate_iterate_cost_ = {1, estimate_doc_freq_};
block_max_bm25_score_cache_part_info_end_ids_.resize(pos_iters_.size(), INVALID_ROWID);
block_max_bm25_score_cache_part_info_vals_.resize(pos_iters_.size());
}

void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader) {
Expand All @@ -41,9 +44,12 @@ void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&col
column_length_reader_ = std::move(column_length_reader);
u64 total_df = column_length_reader_->GetTotalDF();
float avg_column_len = column_length_reader_->GetAvgColumnLength();
float smooth_idf = std::log(1.0F + (total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F));
float smooth_idf = std::log1p((total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F));
bm25_common_score_ = weight_ * smooth_idf * (k1 + 1.0F);
bm25_score_upper_bound_ = bm25_common_score_ / (1.0F + k1 * b / avg_column_len);
f1 = k1 * (1.0F - b);
f2 = k1 * b / avg_column_len;
f3 = f2 * std::numeric_limits<u16>::max();
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
oss << "TermDocIterator: ";
Expand Down Expand Up @@ -80,13 +86,74 @@ bool PhraseDocIterator::Next(const RowID doc_id) {
bool found = GetPhraseMatchData();
if (found && (threshold_ <= 0.0f || BM25Score() > threshold_)) {
doc_id_ = target_doc_id;
UpdateBlockRangeDocID();
return true;
}
++target_doc_id;
}
}
}

void PhraseDocIterator::UpdateBlockRangeDocID() {
RowID min_doc_id = 0;
RowID max_doc_id = INVALID_ROWID;
for (const auto &it : pos_iters_) {
min_doc_id = std::max(min_doc_id, it->BlockLowestPossibleDocID());
max_doc_id = std::min(max_doc_id, it->BlockLastDocID());
}
block_min_possible_doc_id_ = min_doc_id;
block_last_doc_id_ = max_doc_id;
}

float PhraseDocIterator::BlockMaxBM25Score() {
if (const auto last_doc_id = BlockLastDocID(); last_doc_id != block_max_bm25_score_cache_end_id_) {
block_max_bm25_score_cache_end_id_ = last_doc_id;
// bm25_common_score_ / (1.0F + k1 * ((1.0F - b) / block_max_tf + b / block_max_percentage / avg_column_len));
// block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + f1 / block_max_tf + f3 / block_max_percentage_u16);
float div_add_min = std::numeric_limits<float>::max();
for (SizeT i = 0; i < pos_iters_.size(); ++i) {
const auto *iter = pos_iters_[i].get();
float current_div_add_min = {};
if (const auto iter_block_last_doc_id = iter->BlockLastDocID();
iter_block_last_doc_id == block_max_bm25_score_cache_part_info_end_ids_[i]) {
current_div_add_min = block_max_bm25_score_cache_part_info_vals_[i];
} else {
block_max_bm25_score_cache_part_info_end_ids_[i] = iter_block_last_doc_id;
const auto [block_max_tf, block_max_percentage_u16] = iter->GetBlockMaxInfo();
current_div_add_min = f1 / block_max_tf + f3 / block_max_percentage_u16;
block_max_bm25_score_cache_part_info_vals_[i] = current_div_add_min;
}
div_add_min = std::min(div_add_min, current_div_add_min);
}
block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + div_add_min);
}
return block_max_bm25_score_cache_;
}

// Advances the block cursors of all term posting iterators so that the phrase
// block's last doc id is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID when threshold_ exceeds
// BM25ScoreUpperBound() or when any term iterator fails to SkipTo() the target.
// On success only the block range is refreshed; doc_id_ is left untouched.
// Blocks whose BlockMaxBM25Score() does not exceed a positive threshold_ are
// skipped by retrying right after the end of the current block.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool PhraseDocIterator::NextShallow(RowID doc_id) {
    if (threshold_ > BM25ScoreUpperBound()) [[unlikely]] {
        doc_id_ = INVALID_ROWID;
        return false;
    }
    RowID target_doc_id = doc_id;
    for (;;) {
        // Stop at the first term iterator that cannot reach the target.
        for (const auto &term_iter : pos_iters_) {
            const bool reachable = term_iter->SkipTo(target_doc_id);
            if (!reachable) {
                doc_id_ = INVALID_ROWID;
                return false;
            }
        }
        UpdateBlockRangeDocID();
        const bool threshold_disabled = threshold_ <= 0.0f;
        if (threshold_disabled || BlockMaxBM25Score() > threshold_) {
            return true;
        }
        // The current block cannot beat the threshold: continue after its end.
        target_doc_id = BlockLastDocID() + 1;
    }
}

float PhraseDocIterator::BM25Score() {
if (doc_id_ == bm25_score_cache_docid_) [[unlikely]] {
return bm25_score_cache_;
Expand All @@ -112,6 +179,7 @@ void PhraseDocIterator::PrintTree(std::ostream &os, const String &prefix, bool i
}
os << ")";
os << " (doc_freq: " << GetDocFreq() << ")";
os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")";
os << '\n';
}

Expand Down
23 changes: 21 additions & 2 deletions src/storage/invertedindex/search/phrase_doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import posting_iterator;
import index_defines;
import column_length_io;
import parse_fulltext_options;
import blockmax_leaf_iterator;

namespace infinity {

export class PhraseDocIterator final : public DocIterator {
export class PhraseDocIterator final : public BlockMaxLeafIterator {
public:
PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters, float weight, u32 slop, FulltextSimilarity ft_similarity);

Expand All @@ -32,7 +33,21 @@ public:

bool Next(RowID doc_id) override;

float BM25Score();
RowID BlockMinPossibleDocID() const override { return block_min_possible_doc_id_; }

RowID BlockLastDocID() const override { return block_last_doc_id_; }

void UpdateBlockRangeDocID();

float BlockMaxBM25Score() override;

// Move block cursor to ensure its last_doc_id is no less than given doc_id.
// Returns false and update doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine decode skip_list only, and doesn't update doc_id_ when returns true.
// Caller may invoke BlockMaxBM25Score() after this routine.
bool NextShallow(RowID doc_id) override;

float BM25Score() override;

float Score() override {
switch (ft_similarity_) {
Expand Down Expand Up @@ -86,6 +101,10 @@ private:
UniquePtr<FullTextColumnLengthReader> column_length_reader_ = nullptr;
float block_max_bm25_score_cache_ = 0.0f;
RowID block_max_bm25_score_cache_end_id_ = INVALID_ROWID;
Vector<RowID> block_max_bm25_score_cache_part_info_end_ids_;
Vector<float> block_max_bm25_score_cache_part_info_vals_;
RowID block_min_possible_doc_id_ = INVALID_ROWID;
RowID block_last_doc_id_ = INVALID_ROWID;

float tf_ = 0.0f; // current doc_id_'s tf
u32 estimate_doc_freq_{0}; // estimated at the beginning
Expand Down
Loading

0 comments on commit de0f9b3

Please sign in to comment.