Skip to content

Commit

Permalink
add bmw_iterator_interface
Browse files Browse the repository at this point in the history
  • Loading branch information
yangzq50 committed Dec 13, 2024
1 parent 38f274e commit 6cd5f2b
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 8 deletions.
42 changes: 42 additions & 0 deletions src/storage/invertedindex/search/blockmax_leaf_iterator.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module blockmax_leaf_iterator;

import stl;
import internal_types;
import doc_iterator;

namespace infinity {

// Abstract interface for leaf doc iterators that expose block-level information
// (block doc-id range and a block-level BM25 score bound) on top of DocIterator.
// Every member declared here is pure virtual and must be provided by the concrete iterator.
export class BlockMaxLeafIterator : public DocIterator {
public:
// Lower end of the doc-id range of the current block.
// NOTE(review): the name suggests this is a *possible* lowest doc id, not necessarily one that exists - confirm with implementations.
virtual RowID BlockMinPossibleDocID() const = 0;

// Last doc id of the current block.
virtual RowID BlockLastDocID() const = 0;

// BM25 score bound for the current block.
// Non-const, so implementations are free to cache the computed value.
virtual float BlockMaxBM25Score() = 0;

// Moves the block cursor so that the block's last doc id is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine only decodes the skip list and does not update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
virtual bool NextShallow(RowID doc_id) = 0;

// BM25 score at the iterator's current position.
// NOTE(review): presumably the score of the document identified by doc_id_ - confirm with implementations.
virtual float BM25Score() = 0;
};

} // namespace infinity
68 changes: 68 additions & 0 deletions src/storage/invertedindex/search/phrase_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module;

#include <cassert>
#include <iostream>
#include <vector>

module phrase_doc_iterator;

Expand Down Expand Up @@ -31,6 +32,8 @@ PhraseDocIterator::PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters,
estimate_doc_freq_ = std::min(estimate_doc_freq_, pos_iters_[i]->GetDocFreq());
}
estimate_iterate_cost_ = {1, estimate_doc_freq_};
block_max_bm25_score_cache_part_info_end_ids_.resize(pos_iters_.size(), INVALID_ROWID);
block_max_bm25_score_cache_part_info_vals_.resize(pos_iters_.size());
}

void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader) {
Expand All @@ -44,6 +47,9 @@ void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&col
float smooth_idf = std::log1p((total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F));
bm25_common_score_ = weight_ * smooth_idf * (k1 + 1.0F);
bm25_score_upper_bound_ = bm25_common_score_ / (1.0F + k1 * b / avg_column_len);
f1 = k1 * (1.0F - b);
f2 = k1 * b / avg_column_len;
f3 = f2 * std::numeric_limits<u16>::max();
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
oss << "TermDocIterator: ";
Expand Down Expand Up @@ -80,13 +86,74 @@ bool PhraseDocIterator::Next(const RowID doc_id) {
bool found = GetPhraseMatchData();
if (found && (threshold_ <= 0.0f || BM25Score() > threshold_)) {
doc_id_ = target_doc_id;
UpdateBlockRangeDocID();
return true;
}
++target_doc_id;
}
}
}

void PhraseDocIterator::UpdateBlockRangeDocID() {
RowID min_doc_id = 0;
RowID max_doc_id = INVALID_ROWID;
for (const auto &it : pos_iters_) {
min_doc_id = std::max(min_doc_id, it->BlockLowestPossibleDocID());
max_doc_id = std::min(max_doc_id, it->BlockLastDocID());
}
block_min_possible_doc_id_ = min_doc_id;
block_last_doc_id_ = max_doc_id;
}

float PhraseDocIterator::BlockMaxBM25Score() {
if (const auto last_doc_id = BlockLastDocID(); last_doc_id != block_max_bm25_score_cache_end_id_) {
block_max_bm25_score_cache_end_id_ = last_doc_id;
// bm25_common_score_ / (1.0F + k1 * ((1.0F - b) / block_max_tf + b / block_max_percentage / avg_column_len));
// block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + f1 / block_max_tf + f3 / block_max_percentage_u16);
float div_add_min = std::numeric_limits<float>::max();
for (SizeT i = 0; i < pos_iters_.size(); ++i) {
const auto *iter = pos_iters_[i].get();
float current_div_add_min = {};
if (const auto iter_block_last_doc_id = iter->BlockLastDocID();
iter_block_last_doc_id == block_max_bm25_score_cache_part_info_end_ids_[i]) {
current_div_add_min = block_max_bm25_score_cache_part_info_vals_[i];
} else {
block_max_bm25_score_cache_part_info_end_ids_[i] = iter_block_last_doc_id;
const auto [block_max_tf, block_max_percentage_u16] = iter->GetBlockMaxInfo();
current_div_add_min = f1 / block_max_tf + f3 / block_max_percentage_u16;
block_max_bm25_score_cache_part_info_vals_[i] = current_div_add_min;
}
div_add_min = std::min(div_add_min, current_div_add_min);
}
block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + div_add_min);
}
return block_max_bm25_score_cache_;
}

// Moves the block cursor of every term posting iterator so that the block's last doc id
// is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID when threshold_ exceeds BM25ScoreUpperBound()
// or when any posting iterator fails to skip to the requested doc id.
// When threshold_ is positive, block ranges whose BlockMaxBM25Score() does not exceed it are
// skipped by retrying from BlockLastDocID() + 1.
// Note that doc_id_ is not updated when this routine returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool PhraseDocIterator::NextShallow(RowID doc_id) {
    if (threshold_ > BM25ScoreUpperBound()) [[unlikely]] {
        doc_id_ = INVALID_ROWID;
        return false;
    }
    for (;;) {
        for (const auto &posting : pos_iters_) {
            if (posting->SkipTo(doc_id)) {
                continue;
            }
            // One of the terms has no block left: the whole phrase is exhausted.
            doc_id_ = INVALID_ROWID;
            return false;
        }
        UpdateBlockRangeDocID();
        // The block score is only evaluated when a positive threshold is set.
        const bool threshold_disabled = threshold_ <= 0.0f;
        if (threshold_disabled || BlockMaxBM25Score() > threshold_) {
            return true;
        }
        // This block range cannot beat the threshold: try the range right after it.
        doc_id = BlockLastDocID() + 1;
    }
}

float PhraseDocIterator::BM25Score() {
if (doc_id_ == bm25_score_cache_docid_) [[unlikely]] {
return bm25_score_cache_;
Expand All @@ -112,6 +179,7 @@ void PhraseDocIterator::PrintTree(std::ostream &os, const String &prefix, bool i
}
os << ")";
os << " (doc_freq: " << GetDocFreq() << ")";
os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")";
os << '\n';
}

Expand Down
23 changes: 21 additions & 2 deletions src/storage/invertedindex/search/phrase_doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import posting_iterator;
import index_defines;
import column_length_io;
import parse_fulltext_options;
import blockmax_leaf_iterator;

namespace infinity {

export class PhraseDocIterator final : public DocIterator {
export class PhraseDocIterator final : public BlockMaxLeafIterator {
public:
PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters, float weight, u32 slop, FulltextSimilarity ft_similarity);

Expand All @@ -32,7 +33,21 @@ public:

bool Next(RowID doc_id) override;

float BM25Score();
RowID BlockMinPossibleDocID() const override { return block_min_possible_doc_id_; }

RowID BlockLastDocID() const override { return block_last_doc_id_; }

void UpdateBlockRangeDocID();

float BlockMaxBM25Score() override;

// Move the block cursor so that the block's last_doc_id is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine only decodes the skip list and does not update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool NextShallow(RowID doc_id) override;

float BM25Score() override;

float Score() override {
switch (ft_similarity_) {
Expand Down Expand Up @@ -86,6 +101,10 @@ private:
UniquePtr<FullTextColumnLengthReader> column_length_reader_ = nullptr;
float block_max_bm25_score_cache_ = 0.0f;
RowID block_max_bm25_score_cache_end_id_ = INVALID_ROWID;
Vector<RowID> block_max_bm25_score_cache_part_info_end_ids_;
Vector<float> block_max_bm25_score_cache_part_info_vals_;
RowID block_min_possible_doc_id_ = INVALID_ROWID;
RowID block_last_doc_id_ = INVALID_ROWID;

float tf_ = 0.0f; // current doc_id_'s tf
u32 estimate_doc_freq_{0}; // estimated at the beginning
Expand Down
13 changes: 7 additions & 6 deletions src/storage/invertedindex/search/term_doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ import doc_iterator;
import column_length_io;
import third_party;
import parse_fulltext_options;
import blockmax_leaf_iterator;

namespace infinity {

export class TermDocIterator final : public DocIterator {
export class TermDocIterator final : public BlockMaxLeafIterator {
public:
TermDocIterator(UniquePtr<PostingIterator> &&iter, u64 column_id, float weight, FulltextSimilarity ft_similarity);

Expand All @@ -48,15 +49,15 @@ public:

void InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader);

RowID BlockMinPossibleDocID() const { return iter_->BlockLowestPossibleDocID(); }
RowID BlockLastDocID() const { return iter_->BlockLastDocID(); }
float BlockMaxBM25Score();
RowID BlockMinPossibleDocID() const override { return iter_->BlockLowestPossibleDocID(); }
RowID BlockLastDocID() const override { return iter_->BlockLastDocID(); }
float BlockMaxBM25Score() override;

// Move the block cursor so that the block's last_doc_id is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine only decodes the skip list and does not update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool NextShallow(RowID doc_id);
bool NextShallow(RowID doc_id) override;

// Overridden methods
DocIteratorType GetType() const override { return DocIteratorType::kTermDocIterator; }
Expand All @@ -65,7 +66,7 @@ public:

bool Next(RowID doc_id) override;

float BM25Score();
float BM25Score() override;

float Score() override {
switch (ft_similarity_) {
Expand Down

0 comments on commit 6cd5f2b

Please sign in to comment.