Skip to content

Commit

Permalink
Add unittest for inverted index search iterator (#754)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Add unit tests for the inverted index search iterators (AndIterator, OrIterator, AndNotIterator).

Issue link: #641

### Type of change

- [x] Refactoring
- [x] Test cases
  • Loading branch information
yangzq50 authored Mar 11, 2024
1 parent 24907d0 commit de951a1
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 40 deletions.
34 changes: 13 additions & 21 deletions src/storage/invertedindex/search/and_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module;

#include <vector>
module and_iterator;

import stl;
Expand All @@ -28,34 +29,25 @@ AndIterator::AndIterator(Vector<UniquePtr<DocIterator>> iterators) {
sorted_iterators_.push_back(children_[i].get());
}
std::sort(sorted_iterators_.begin(), sorted_iterators_.end(), [](const auto lhs, const auto rhs) { return lhs->GetDF() < rhs->GetDF(); });
// initialize doc_id_ to first doc
DoSeek(0);
}

AndIterator::~AndIterator() {}

void AndIterator::DoSeek(docid_t doc_id) {
DocIterator **first_iter = &sorted_iterators_[0];
DocIterator **current_iter = first_iter;
DocIterator **end_iter = first_iter + sorted_iterators_.size();
docid_t current = doc_id;
do {
docid_t tmp_id = INVALID_DOCID;
tmp_id = doc_id_;
if (tmp_id == INVALID_DOCID) {
current = tmp_id;
break;
} else if (tmp_id != current) {
current = tmp_id;
current_iter = first_iter;
auto ib = sorted_iterators_.begin(), ie = sorted_iterators_.end();
while (ib != ie) {
(*ib)->Seek(doc_id);
if (docid_t doc = (*ib)->Doc(); doc != doc_id) {
// not match, restart from the first iterator, since first iterator has fewer docs
doc_id = doc;
ib = sorted_iterators_.begin();
} else {
current_iter++;
if (current_iter >= end_iter) {
current++;
current_iter = first_iter;
}
++ib;
}

} while (true);
doc_id_ = current;
}
doc_id_ = doc_id;
}

u32 AndIterator::GetDF() const {
Expand Down
34 changes: 24 additions & 10 deletions src/storage/invertedindex/search/and_not_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,35 @@ import doc_iterator;

namespace infinity {

AndNotIterator::AndNotIterator(Vector<UniquePtr<DocIterator>> iterators) { children_ = std::move(iterators); }
AndNotIterator::AndNotIterator(Vector<UniquePtr<DocIterator>> iterators) {
children_ = std::move(iterators);
// initialize doc_id_ to first valid doc
DoSeek(0);
}

AndNotIterator::~AndNotIterator() {}

void AndNotIterator::DoSeek(docid_t doc_id) {
if (!children_[0]->Seek(doc_id)) {
// not match in positive child
return;
}
for (u32 i = 1; i < children_.size(); ++i) {
if (children_[i]->Seek(doc_id)) {
// match in negative child
return;
bool next_loop = false;
do {
children_[0]->Seek(doc_id);
if (docid_t doc = children_[0]->Doc(); doc != doc_id) {
doc_id = doc;
}
if (doc_id == INVALID_DOCID) {
break;
}
// now doc_id < INVALID_DOCID
next_loop = false;
for (u32 i = 1; i < children_.size(); ++i) {
children_[i]->Seek(doc_id);
if (docid_t doc = children_[i]->Doc(); doc == doc_id) {
++doc_id;
next_loop = true;
break;
}
}
}
} while (next_loop);
doc_id_ = doc_id;
}

Expand Down
25 changes: 18 additions & 7 deletions src/storage/invertedindex/search/or_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,30 @@ import index_defines;
import multi_query_iterator;
import doc_iterator;
namespace infinity {
OrIterator::OrIterator(Vector<UniquePtr<DocIterator>> iterators) { children_ = std::move(iterators); }
OrIterator::OrIterator(Vector<UniquePtr<DocIterator>> iterators) {
children_ = std::move(iterators);
count_ = children_.size();
iterator_heap_.resize(children_.size() + 1);
for (u32 i = 0; i < children_.size(); ++i) {
iterator_heap_[i + 1].doc_id_ = children_[i]->Doc();
iterator_heap_[i + 1].entry_id_ = i;
}
// Build the heap
for (u32 i = children_.size() / 2; i > 0; --i) {
AdjustDown(i);
}
doc_id_ = iterator_heap_[1].doc_id_;
}

OrIterator::~OrIterator() {}

void OrIterator::DoSeek(docid_t id) {
docid_t doc_id = INVALID_DOCID;
do {
while (id > iterator_heap_[1].doc_id_) {
DocIterator *top = GetDocIterator(iterator_heap_[1].entry_id_);
top->Seek(id);
doc_id = top->Doc();
iterator_heap_[1].doc_id_ = doc_id;
AdjustDown();
} while (id > iterator_heap_[1].doc_id_);
iterator_heap_[1].doc_id_ = top->Doc();
AdjustDown(1);
}
doc_id_ = iterator_heap_[1].doc_id_;
}

Expand Down
3 changes: 1 addition & 2 deletions src/storage/invertedindex/search/or_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ public:
private:
DocIterator *GetDocIterator(u32 i) { return children_[i].get(); }

void AdjustDown() {
void AdjustDown(u32 idx) {
DocIteratorEntry *heap = iterator_heap_.data();
u32 idx = 1;
u32 min = idx;
do {
idx = min;
Expand Down
205 changes: 205 additions & 0 deletions src/unit_test/storage/invertedindex/search/iterator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "unit_test/base_test.h"
#include <random>
import stl;
import index_defines;
import default_values;
import doc_iterator;
import and_iterator;
import or_iterator;
import and_not_iterator;

using namespace infinity;

// Test double for DocIterator that walks an in-memory list of doc ids.
//
// The list is expected to be sorted in ascending order (every caller in this
// file passes a sorted, de-duplicated vector). The iterator only moves forward:
// seeking to an id smaller than the current position leaves it where it is.
class MockVectorDocIterator : public DocIterator {
private:
    // Intentionally empty: this mock does nothing with the iterator passed in.
    void AddIterator(DocIterator *iter) override {}

    // The mock always reports a document frequency of 0.
    u32 GetDF() const override { return 0; }

public:
    // Takes ownership of `doc_ids` and positions the iterator on the first entry
    // (or INVALID_DOCID when the list is empty).
    // `explicit` prevents a Vector<docid_t> from silently converting into an iterator.
    explicit MockVectorDocIterator(Vector<docid_t> doc_ids) : doc_ids_(std::move(doc_ids)) { DoSeek(0); }

    ~MockVectorDocIterator() override = default;

    // Advances to the first stored id that is >= `doc_id` and publishes it in
    // doc_id_ (not declared in this class; presumably inherited from DocIterator).
    // When no such id remains, doc_id_ becomes INVALID_DOCID.
    void DoSeek(docid_t doc_id) override {
        while (idx_ < doc_ids_.size() && doc_ids_[idx_] < doc_id) {
            ++idx_;
        }
        doc_id_ = idx_ < doc_ids_.size() ? doc_ids_[idx_] : INVALID_DOCID;
    }

    Vector<docid_t> doc_ids_; // doc ids to iterate over
    u32 idx_ = 0;             // index into doc_ids_ of the current position
};

// Produces a sorted, duplicate-free list of random doc ids.
//   - each id is drawn uniformly from [0, 100'000]
//   - the number of draws is itself random in [0, param_len], so after
//     de-duplication the result holds between 0 and param_len ids
auto get_random_doc_ids = [](std::mt19937 &rng, u32 param_len) -> Vector<docid_t> {
    std::uniform_int_distribution<u32> pick_count(0, param_len);
    std::uniform_int_distribution<docid_t> pick_id(0, 100'000);

    // The count is drawn before any id, so the RNG is consumed in a fixed order.
    const u32 draw_count = pick_count(rng);

    Vector<docid_t> result;
    for (u32 drawn = 0; drawn != draw_count; ++drawn) {
        result.push_back(pick_id(rng));
    }

    // Sort first so that std::unique can collapse equal neighbours.
    std::sort(result.begin(), result.end());
    const auto duplicates_begin = std::unique(result.begin(), result.end());
    result.erase(duplicates_begin, result.end());
    return result;
};

// Fixture for the two-operand iterator tests.
//
// SetUp() generates two random sorted doc-id lists (A and B) and precomputes the
// reference answers with the standard set algorithms:
//   doc_ids_and     = A intersect B
//   doc_ids_or      = A union B
//   doc_ids_and_not = A minus B
// The RNG is seeded from std::random_device, so every run uses different data.
class SearchIteratorTest2 : public BaseTest {
public:
    Vector<docid_t> doc_ids_A;
    Vector<docid_t> doc_ids_B;
    Vector<docid_t> doc_ids_and;
    Vector<docid_t> doc_ids_or;
    Vector<docid_t> doc_ids_and_not;

    SearchIteratorTest2() = default;

    void SetUp() override {
        std::random_device seed_source;
        std::mt19937 rng(seed_source());

        // Inputs: A is generated before B so the RNG is consumed in a fixed order.
        doc_ids_A = get_random_doc_ids(rng, 10'000);
        doc_ids_B = get_random_doc_ids(rng, 10'000);

        const auto a_first = doc_ids_A.begin();
        const auto a_last = doc_ids_A.end();
        const auto b_first = doc_ids_B.begin();
        const auto b_last = doc_ids_B.end();

        // Expected results, each rebuilt from scratch.
        doc_ids_and.clear();
        std::set_intersection(a_first, a_last, b_first, b_last, std::back_inserter(doc_ids_and));

        doc_ids_or.clear();
        std::set_union(a_first, a_last, b_first, b_last, std::back_inserter(doc_ids_or));

        doc_ids_and_not.clear();
        std::set_difference(a_first, a_last, b_first, b_last, std::back_inserter(doc_ids_and_not));
    }

    void TearDown() override {}
};

// AndIterator over A and B must agree with the precomputed intersection when
// both are probed with Seek() at every doc id in [0, 100'000].
TEST_F(SearchIteratorTest2, test_and) {
    Vector<UniquePtr<DocIterator>> children;
    children.push_back(MakeUnique<MockVectorDocIterator>(doc_ids_A));
    children.push_back(MakeUnique<MockVectorDocIterator>(doc_ids_B));

    AndIterator and_it(std::move(children));
    MockVectorDocIterator expect_res(doc_ids_and);

    docid_t probe = 0;
    while (probe <= 100'000) {
        and_it.Seek(probe);
        expect_res.Seek(probe);
        EXPECT_EQ(and_it.Doc(), expect_res.Doc());
        ++probe;
    }
}

// OrIterator over A and B must agree with the precomputed union when both are
// probed with Seek() at every doc id in [0, 100'000].
TEST_F(SearchIteratorTest2, test_or) {
    Vector<UniquePtr<DocIterator>> children;
    children.push_back(MakeUnique<MockVectorDocIterator>(doc_ids_A));
    children.push_back(MakeUnique<MockVectorDocIterator>(doc_ids_B));

    OrIterator or_it(std::move(children));
    MockVectorDocIterator expect_res(doc_ids_or);

    docid_t probe = 0;
    while (probe <= 100'000) {
        or_it.Seek(probe);
        expect_res.Seek(probe);
        EXPECT_EQ(or_it.Doc(), expect_res.Doc());
        ++probe;
    }
}

// AndNotIterator with A as the first child and B as the second must agree with
// the precomputed difference (A minus B) when both are probed with Seek() at
// every doc id in [0, 100'000].
TEST_F(SearchIteratorTest2, test_and_not) {
    Vector<UniquePtr<DocIterator>> children;
    children.push_back(MakeUnique<MockVectorDocIterator>(doc_ids_A));
    children.push_back(MakeUnique<MockVectorDocIterator>(doc_ids_B));

    AndNotIterator and_not_it(std::move(children));
    MockVectorDocIterator expect_res(doc_ids_and_not);

    docid_t probe = 0;
    while (probe <= 100'000) {
        and_not_it.Seek(probe);
        expect_res.Seek(probe);
        EXPECT_EQ(and_not_it.Doc(), expect_res.Doc());
        ++probe;
    }
}

// Number of child iterators (and random doc-id lists) used by the N-way tests.
// A typed constant instead of a macro: it obeys scoping rules and still works
// as an array bound.
constexpr int TestN = 5;

class SearchIteratorTestN : public BaseTest {
public:
Vector<docid_t> doc_ids[TestN], doc_ids_and, doc_ids_or, doc_ids_and_not;

SearchIteratorTestN() {}

void SetUp() override {
// prepare random seed
std::random_device rd;
std::mt19937 rng(rd());
for (int i = 0; i < TestN; ++i) {
doc_ids[i] = get_random_doc_ids(rng, 30'000);
}
// calculate and, or, and_not
doc_ids_and = doc_ids[0];
doc_ids_or = doc_ids[0];
doc_ids_and_not = doc_ids[0];
for (int i = 1; i < TestN; ++i) {
Vector<docid_t> new_and, new_or, new_and_not;
std::set_intersection(doc_ids_and.begin(), doc_ids_and.end(), doc_ids[i].begin(), doc_ids[i].end(), std::back_inserter(new_and));
std::set_union(doc_ids_or.begin(), doc_ids_or.end(), doc_ids[i].begin(), doc_ids[i].end(), std::back_inserter(new_or));
std::set_difference(doc_ids_and_not.begin(),
doc_ids_and_not.end(),
doc_ids[i].begin(),
doc_ids[i].end(),
std::back_inserter(new_and_not));
doc_ids_and = std::move(new_and);
doc_ids_or = std::move(new_or);
doc_ids_and_not = std::move(new_and_not);
}
}

void TearDown() override {}
};

// AndIterator over all TestN lists must agree with the precomputed N-way
// intersection when both are probed with Seek() at every doc id in [0, 100'000].
TEST_F(SearchIteratorTestN, test_and) {
    Vector<UniquePtr<DocIterator>> children;
    for (const auto &ids : doc_ids) {
        children.push_back(MakeUnique<MockVectorDocIterator>(ids));
    }

    AndIterator and_it(std::move(children));
    MockVectorDocIterator expect_res(doc_ids_and);

    docid_t probe = 0;
    while (probe <= 100'000) {
        and_it.Seek(probe);
        expect_res.Seek(probe);
        EXPECT_EQ(and_it.Doc(), expect_res.Doc());
        ++probe;
    }
}

// OrIterator over all TestN lists must agree with the precomputed N-way union
// when both are probed with Seek() at every doc id in [0, 100'000].
TEST_F(SearchIteratorTestN, test_or) {
    Vector<UniquePtr<DocIterator>> children;
    for (const auto &ids : doc_ids) {
        children.push_back(MakeUnique<MockVectorDocIterator>(ids));
    }

    OrIterator or_it(std::move(children));
    MockVectorDocIterator expect_res(doc_ids_or);

    docid_t probe = 0;
    while (probe <= 100'000) {
        or_it.Seek(probe);
        expect_res.Seek(probe);
        EXPECT_EQ(or_it.Doc(), expect_res.Doc());
        ++probe;
    }
}

// AndNotIterator with doc_ids[0] as the first child and the remaining lists
// after it must agree with the precomputed difference (doc_ids[0] minus all the
// others) when both are probed with Seek() at every doc id in [0, 100'000].
TEST_F(SearchIteratorTestN, test_and_not) {
    Vector<UniquePtr<DocIterator>> children;
    for (const auto &ids : doc_ids) {
        children.push_back(MakeUnique<MockVectorDocIterator>(ids));
    }

    AndNotIterator and_not_it(std::move(children));
    MockVectorDocIterator expect_res(doc_ids_and_not);

    docid_t probe = 0;
    while (probe <= 100'000) {
        and_not_it.Seek(probe);
        expect_res.Seek(probe);
        EXPECT_EQ(and_not_it.Doc(), expect_res.Doc());
        ++probe;
    }
}

0 comments on commit de951a1

Please sign in to comment.