Skip to content

Commit

Permalink
Init scorer framework (#693)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Add init scorer, BM25 is initially added
TODO:
Column length reader and writer has not been implemented.
More rankers are required

Issue link:#421

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
yingfeng authored Mar 1, 2024
1 parent 97dbbc0 commit a217862
Show file tree
Hide file tree
Showing 14 changed files with 311 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/storage/invertedindex/common/external_sort_merger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ void SortMerger<KeyType, LenType>::Run() {
IASSERT(out_f);
IASSERT(fwrite(&count_, sizeof(u64), 1, out_f) == 1);

Thread *out_thread[OUT_BUF_NUM_];
Vector<Thread *> out_thread(OUT_BUF_NUM_);
for (u32 i = 0; i < OUT_BUF_NUM_; ++i)
out_thread[i] = new Thread(std::bind(&self_t::Output, this, out_f, i));

Expand Down
26 changes: 26 additions & 0 deletions src/storage/invertedindex/posting_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import in_doc_pos_state;
import multi_posting_decoder;
import segment_posting;
import index_defines;
import match_data;
export module posting_iterator;

namespace infinity {
Expand All @@ -29,6 +30,17 @@ public:

void SeekPosition(pos_t pos, pos_t &result);

tf_t GetCurrentTF() { return state_.GetTermFreq(); }

docpayload_t GetCurrentDocPayload() {
if (posting_option_.HasDocPayload()) {
DecodeTFBuffer();
DecodeDocPayloadBuffer();
return doc_payload_buffer_[GetDocOffsetInBuffer()];
}
return 0;
}

ttf_t GetCurrentTTF() {
if (posting_option_.HasTfList()) {
DecodeTFBuffer();
Expand All @@ -42,6 +54,20 @@ public:

bool HasPosition() const { return posting_option_.HasPositionList(); }

void GetTermMatchData(TermColumnMatchData &match_data) {
DecodeTFBuffer();
DecodeDocPayloadBuffer();
if (need_move_to_current_doc_) {
MoveToCurrentDoc();
}
if (posting_option_.HasTfList()) {
match_data.tf_ = tf_buffer_[GetDocOffsetInBuffer()];
}
if (posting_option_.HasDocPayload()) {
match_data.doc_payload_ = doc_payload_buffer_[GetDocOffsetInBuffer()];
}
}

private:
u32 GetCurrentSeekedDocCount() const { return posting_decoder_->InnerGetSeekedDocCount() + (GetDocOffsetInBuffer() + 1); }

Expand Down
35 changes: 35 additions & 0 deletions src/storage/invertedindex/search/bm25_ranker.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

#include <cmath>

module bm25_ranker;

import stl;

namespace infinity {

constexpr float k1 = 1.2F;
constexpr float b = 0.75F;

BM25Ranker::BM25Ranker(u64 total_df) : total_df_(std::max(total_df, 1UL)) {}

void BM25Ranker::AddTermParam(u64 tf, u64 df, double avg_column_len, u64 column_len) {
float smooth_idf = std::log(1.0F + (total_df_ - df + 0.5F) / (df + 0.5F));
float smooth_tf = (k1 + 1.0F) * tf / (tf + k1 * (1.0F - b + b * column_len / avg_column_len));
score_ += smooth_idf * smooth_tf;
}
} // namespace infinity
35 changes: 35 additions & 0 deletions src/storage/invertedindex/search/bm25_ranker.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module bm25_ranker;

import stl;

namespace infinity {
export class BM25Ranker {
public:
BM25Ranker(u64 total_df);
~BM25Ranker() = default;

void AddTermParam(u64 tf, u64 df, double avg_column_len, u64 column_len);

float GetScore() { return score_; }

private:
float score_{0};
i64 total_df_{0};
};
} // namespace infinity
32 changes: 32 additions & 0 deletions src/storage/invertedindex/search/column_length_io.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module column_length_io;

import stl;
import index_defines;

namespace infinity {
export class ColumnLengthWriter {};

export class ColumnLengthReader {
public:
ColumnLengthReader() = default;
~ColumnLengthReader() = default;

u32 GetColumnLength(u64 column_id, docid_t doc_id) { return 0; }
};
} // namespace infinity
3 changes: 3 additions & 0 deletions src/storage/invertedindex/search/doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import posting_iterator;
import index_defines;
import segment;
import index_config;
import match_data;
namespace infinity {
export class DocIterator {
public:
Expand All @@ -47,6 +48,8 @@ public:

virtual u32 GetDF() const = 0;

virtual bool GetTermMatchData(TermColumnMatchData &match_data, docid_t doc_id) { return false; }

protected:
docid_t doc_id_;
};
Expand Down
80 changes: 80 additions & 0 deletions src/storage/invertedindex/search/match_data.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

module match_data;

import stl;
import term_doc_iterator;
import index_defines;
import bm25_ranker;

namespace infinity {

Scorer::Scorer(u64 num_of_docs) : total_df_(num_of_docs) { column_length_reader_ = MakeUnique<ColumnLengthReader>(); }

u32 Scorer::GetOrSetColumnIndex(u64 column_id) {
if (column_index_map_.find(column_id) == column_index_map_.end()) {
column_index_map_[column_id] = column_counter_;
match_data_.term_columns_.resize(column_counter_ + 1);
return column_counter_++;
} else
return column_index_map_[column_id];
}

void Scorer::InitRanker(const Map<u64, double> &weight_map) {
for (auto it : weight_map) {
u32 column_index = GetOrSetColumnIndex(it.first);
column_weights_.resize(column_index + 1);
column_weights_[column_index] = it.second;
column_ids_.resize(column_index + 1);
column_ids_[column_index] = it.first;
}
for (u32 i = 0; i < column_counter_; ++i) {
avg_column_length_[i] = GetAvgColumnLength(column_ids_[i]);
}
}

double Scorer::GetAvgColumnLength(u64 column_id) {
double length = 0.0F;
// TODO
return length;
}

void Scorer::AddDocIterator(TermDocIterator *iter, u64 column_id) {
u32 column_index = GetOrSetColumnIndex(column_id);
iterators_.resize(column_index + 1);
iterators_[column_index].push_back(iter);
}

float Scorer::Score(docid_t doc_id) {
float score = 0.0F;
for (u32 i = 0; i < column_counter_; i++) {
BM25Ranker ranker(total_df_);
u32 column_len = column_length_reader_->GetColumnLength(column_ids_[i], doc_id);
Vector<TermDocIterator *> &column_iters = iterators_[i];
TermColumnMatchData &column_match_data = match_data_.term_columns_[i];
for (u32 j = 0; j < column_iters.size(); j++) {
if (column_iters[j]->GetTermMatchData(column_match_data, doc_id)) {
ranker.AddTermParam(column_match_data.tf_, column_iters[j]->GetDF(), avg_column_length_[i], column_len);
}
}
auto s = ranker.GetScore();
score += column_weights_[i] * s;
}
return score;
}

} // namespace infinity
70 changes: 70 additions & 0 deletions src/storage/invertedindex/search/match_data.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module match_data;

import stl;
import third_party;
import index_defines;
import column_length_io;

namespace infinity {
export struct TermColumnMatchData {
docid_t doc_id_;
tf_t tf_;
docpayload_t doc_payload_;
};

export struct MatchData {
Vector<TermColumnMatchData> term_columns_;

TermColumnMatchData *ResolveTermColumn(u32 column_sequence) { return &term_columns_[column_sequence]; }
};

class TermDocIterator;
export class Scorer {
public:
Scorer(u64 num_of_docs);

~Scorer() = default;

void InitRanker(const Map<u64, double> &weight_map);

void AddDocIterator(TermDocIterator *iter, u64 column_id);

float Score(docid_t doc_id);

private:
u32 GetOrSetColumnIndex(u64 column_id);

double GetAvgColumnLength(u64 column_id);

struct Hash {
inline u64 operator()(const u64 &val) const { return val; }
};

u64 total_df_;
u32 column_counter_{0};
FlatHashMap<u64, u32, Hash> column_index_map_;
Vector<u64> column_ids_;
Vector<Vector<TermDocIterator *>> iterators_;
Vector<double> column_weights_;
Vector<double> avg_column_length_;
UniquePtr<ColumnLengthReader> column_length_reader_;
MatchData match_data_;
};

} // namespace infinity
11 changes: 7 additions & 4 deletions src/storage/invertedindex/search/query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import and_not_iterator;
import or_iterator;
import column_index_reader;
import posting_iterator;
import match_data;

namespace infinity {

UniquePtr<TermQuery> TermQuery::Optimize(UniquePtr<TermQuery> query) {
Expand All @@ -33,12 +35,13 @@ UniquePtr<TermQuery> TermQuery::Optimize(UniquePtr<TermQuery> query) {
return UniquePtr<TermQuery>(root);
}

UniquePtr<DocIterator> TermQuery::CreateSearch(IndexReader &index_reader) {
UniquePtr<DocIterator> TermQuery::CreateSearch(IndexReader &index_reader, Scorer *scorer) {
ColumnIndexReader *column_index_reader = index_reader.GetColumnIndexReader(column_.column_id_);
PostingIterator *posting_iterator = column_index_reader->Lookup(term_, index_reader.session_pool_.get());
if (posting_iterator == nullptr)
return nullptr;
UniquePtr<TermDocIterator> search = MakeUnique<TermDocIterator>(posting_iterator);
UniquePtr<TermDocIterator> search = MakeUnique<TermDocIterator>(posting_iterator, column_.column_id_);
scorer->AddDocIterator(search.get(), column_.column_id_);
return std::move(search);
}

Expand Down Expand Up @@ -74,11 +77,11 @@ void MultiQuery::Optimize(TermQuery *&self) {
OptimizeSelf();
}

UniquePtr<DocIterator> MultiQuery::CreateSearch(IndexReader &index_reader) {
UniquePtr<DocIterator> MultiQuery::CreateSearch(IndexReader &index_reader, Scorer *scorer) {
Vector<UniquePtr<DocIterator>> sub_doc_iters;
sub_doc_iters.reserve(children_.size());
for (u32 i = 0; i < children_.size(); ++i) {
auto iter = children_[i]->CreateSearch(index_reader);
auto iter = children_[i]->CreateSearch(index_reader, scorer);
if (iter)
sub_doc_iters.push_back(std::move(iter));
}
Expand Down
5 changes: 3 additions & 2 deletions src/storage/invertedindex/search/query.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export module term_queries;

import stl;
import doc_iterator;
import match_data;
import column_index_reader;

namespace infinity {
Expand Down Expand Up @@ -51,7 +52,7 @@ public:

virtual bool IsOr() const { return false; }

virtual UniquePtr<DocIterator> CreateSearch(IndexReader &index_reader);
virtual UniquePtr<DocIterator> CreateSearch(IndexReader &index_reader, Scorer *scorer);

protected:
virtual void NotifyChange() {
Expand Down Expand Up @@ -89,7 +90,7 @@ public:

void Optimize(TermQuery *&self) override;

UniquePtr<DocIterator> CreateSearch(IndexReader &index_reader) override;
UniquePtr<DocIterator> CreateSearch(IndexReader &index_reader, Scorer *scorer) override;

virtual UniquePtr<DocIterator> CreateMultiSearch(Vector<UniquePtr<DocIterator>> sub_doc_iters) = 0;

Expand Down
Loading

0 comments on commit a217862

Please sign in to comment.