Skip to content

Commit

Permalink
Add ngram analyzer and ut (#400)
Browse files: browse the repository at this point in the history
* Fix ld warning

* Add ut for standard analyzer

* Add ngram analyzer
  • Loading branch information
yingfeng authored Dec 29, 2023
1 parent 39f8b17 commit 6f49f89
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ endif ()
# Append -pthread to the C++ compiler flags.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")

# Print the final C++ flags at configure time.
MESSAGE(STATUS "C++ Compilation flags: " ${CMAKE_CXX_FLAGS})
# NOTE(review): this listing is a diff view reporting one addition and one deletion, so
# presumably only the second set() below exists in the committed file -- confirm.
# Both variants link libstdc++ and libgcc statically; the second additionally passes
# "-z noexecstack" (presumably the "Fix ld warning" item of this commit -- confirm).
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -static-libstdc++ -static-libgcc")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -static-libstdc++ -static-libgcc -z noexecstack")

#add_definitions(-march=native)
if (CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL "18.0")
Expand Down
71 changes: 71 additions & 0 deletions src/common/analyzer/ngram_analyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

#include "string_utils.h"
#include <iostream>
import stl;
import term;
import stemmer;
import analyzer;
import tokenizer;
module ngram_analyzer;

namespace infinity {

bool NGramAnalyzer::NextInString(const char *data,
SizeT length,
SizeT *__restrict pos,
SizeT *__restrict token_start,
SizeT *__restrict token_length) {
*token_start = *pos;
*token_length = 0;
SizeT code_points = 0;
for (; code_points < ngram_ && *token_start + *token_length < length; ++code_points) {
if (std::isspace(data[*token_start + *token_length])) {
*pos += UTF8SeqLength(static_cast<u8>(data[*pos]));
*token_start = *pos;
*token_length = 0;
return true;
}
SizeT sz = UTF8SeqLength(static_cast<u8>(data[*token_start + *token_length]));
*token_length += sz;
}
*pos += UTF8SeqLength(static_cast<u8>(data[*pos]));
return code_points == ngram_;
}

int NGramAnalyzer::AnalyzeImpl(const Term &input, void *data, HookType func) {
unsigned char level = 0;

SizeT len = input.text_.length();
if (len == 0)
return 0;

SizeT cur = 0;
SizeT token_start = 0;
SizeT token_length = 0;
SizeT offset = input.word_offset_;
while (cur < len && NextInString(input.text_.c_str(), len, &cur, &token_start, &token_length)) {
if (token_length == 0)
continue;
func(data, input.text_.c_str() + token_start, token_length, offset, Term::AND, level, false);
offset++;
}

return 1;
}

} // namespace infinity
18 changes: 16 additions & 2 deletions src/common/analyzer/ngram_analyzer.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,21 @@ import term;
import stemmer;
import analyzer;
import tokenizer;
import common_analyzer;
export module ngram_analyzer;

namespace infinity {}
namespace infinity {
/// Analyzer that emits overlapping n-grams of a fixed number of UTF-8 code points.
export class NGramAnalyzer : public Analyzer {
public:
    /// @param ngram number of code points per emitted token.
    /// Marked explicit so an integer is never silently converted into an analyzer.
    explicit NGramAnalyzer(u32 ngram) : ngram_(ngram) {}

    ~NGramAnalyzer() = default;

protected:
    /// Tokenizes `input` and passes every n-gram to `func` along with `data`.
    int AnalyzeImpl(const Term &input, void *data, HookType func) override;

    /// Reads one n-gram starting at byte `*pos` of `data`, reporting its location through
    /// `token_start` / `token_length` and advancing `*pos` by one code point.
    bool NextInString(const char *data, SizeT length, SizeT *__restrict pos, SizeT *__restrict token_start, SizeT *__restrict token_length);

private:
    /// Number of code points per token.
    u32 ngram_;
};
} // namespace infinity
29 changes: 29 additions & 0 deletions src/common/analyzer/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,33 @@ inline std::string ToLower(std::string const &s) {
return result;
}

/// True when `c` is an ASCII byte (below 128) that is not alphanumeric.
/// Bytes of 128 and above, i.e. parts of multi-byte UTF-8 sequences, are never separators.
inline bool IsUTF8Sep(const uint8_t c) {
    if (c >= 128) {
        return false;
    }
    return std::isalnum(c) == 0;
}

/// Count the leading zero bits of `x`.
///
/// The count is taken in the width of the builtin's operand type (unsigned int,
/// unsigned long or unsigned long long), not in the width of T: a uint8_t value of 1
/// yields 31 where unsigned int is 32 bits wide.
///
/// The __builtin_clz family is undefined for a zero argument, so zero is handled
/// explicitly and reports the full operand width (every bit is a leading zero).
template <typename T>
inline uint32_t GetLeadingZeroBits(T x) {
    if constexpr (sizeof(T) <= sizeof(unsigned int)) {
        if (x == 0) {
            return static_cast<uint32_t>(sizeof(unsigned int) * 8);
        }
        return __builtin_clz(x);
    } else if constexpr (sizeof(T) <= sizeof(unsigned long int)) {
        if (x == 0) {
            return static_cast<uint32_t>(sizeof(unsigned long int) * 8);
        }
        return __builtin_clzl(x);
    } else {
        if (x == 0) {
            return static_cast<uint32_t>(sizeof(unsigned long long) * 8);
        }
        return __builtin_clzll(x);
    }
}

/// Zero-based index of the highest set bit of `x`.
///
/// Computed as (operand width - 1 - leading zeros), where the operand width is the
/// larger of T and unsigned int, matching how GetLeadingZeroBits widens narrow types.
template <typename T>
inline uint32_t BitScanReverse(T x) {
    const size_t operand_bits = std::max<size_t>(sizeof(T), sizeof(unsigned int)) * 8;
    const size_t highest_bit = operand_bits - 1 - GetLeadingZeroBits(x);
    return highest_bit;
}

/// Byte length of the UTF-8 sequence introduced by `first_octet`.
///
/// ASCII bytes (below 0x80) and bytes of 0xF8 and above are treated as one-byte
/// sequences. For every other byte the result is its number of leading one bits,
/// found by locating the highest zero bit of the octet.
inline uint32_t UTF8SeqLength(const uint8_t first_octet) {
    const bool single_byte = first_octet < 0x80 || first_octet >= 0xF8;
    if (!single_byte) {
        // Inverting the octet turns its first zero bit into the highest set bit.
        const uint8_t inverted = static_cast<uint8_t>(~first_octet);
        return 7U - BitScanReverse(inverted);
    }
    return 1;
}

} // namespace infinity
76 changes: 76 additions & 0 deletions src/unit_test/common/analyzer/ngram_analyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "unit_test/base_test.h"
#include <iostream>
import stl;
import term;
import ngram_analyzer;
import standard_analyzer;

using namespace infinity;

// Test fixture for the NGramAnalyzer cases below; adds nothing on top of BaseTest.
class NGramAnalyzerTest : public BaseTest {};

// Bigrams: each word is split into overlapping 2-character tokens, no token spans a
// space, and word offsets simply count the emitted tokens.
TEST_F(NGramAnalyzerTest, test1) {
    NGramAnalyzer analyzer(2);
    TermList term_list;
    String input("hello world 123");
    analyzer.Analyze(input, term_list);

    const char *const expected_terms[] = {"he", "el", "ll", "lo", "wo", "or", "rl", "ld", "12", "23"};
    const unsigned expected_count = sizeof(expected_terms) / sizeof(expected_terms[0]);

    ASSERT_EQ(term_list.size(), expected_count);
    for (unsigned idx = 0; idx < expected_count; ++idx) {
        ASSERT_EQ(term_list[idx].text_, String(expected_terms[idx]));
        ASSERT_EQ(term_list[idx].word_offset_, idx);
    }
}

// Unigrams: every non-space character becomes its own token, numbered consecutively.
TEST_F(NGramAnalyzerTest, test2) {
    NGramAnalyzer analyzer(1);
    TermList term_list;
    String input("abc de fg");
    analyzer.Analyze(input, term_list);

    const char *const expected_terms[] = {"a", "b", "c", "d", "e", "f", "g"};
    const unsigned expected_count = sizeof(expected_terms) / sizeof(expected_terms[0]);

    ASSERT_EQ(term_list.size(), expected_count);
    for (unsigned idx = 0; idx < expected_count; ++idx) {
        ASSERT_EQ(term_list[idx].text_, String(expected_terms[idx]));
        ASSERT_EQ(term_list[idx].word_offset_, idx);
    }
}
114 changes: 114 additions & 0 deletions src/unit_test/common/analyzer/standard_analyzer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "unit_test/base_test.h"

import stl;
import term;
import standard_analyzer;
using namespace infinity;

// Test fixture for the StandardAnalyzer cases below; adds nothing on top of BaseTest.
class StandardAnalyzerTest : public BaseTest {};

// Default settings: terms come out lower-cased and the trailing '.' is not a term.
TEST_F(StandardAnalyzerTest, test1) {
    StandardAnalyzer analyzer;
    TermList term_list;
    String input("Boost unit tests.");
    analyzer.Analyze(input, term_list);

    struct ExpectedTerm {
        const char *text;
        unsigned offset;
    };
    const ExpectedTerm expected[] = {{"boost", 0U}, {"unit", 1U}, {"tests", 2U}};
    const unsigned expected_count = sizeof(expected) / sizeof(expected[0]);

    ASSERT_EQ(term_list.size(), expected_count);
    for (unsigned idx = 0; idx < expected_count; ++idx) {
        ASSERT_EQ(term_list[idx].text_, String(expected[idx].text));
        ASSERT_EQ(term_list[idx].word_offset_, expected[idx].offset);
    }
    // NOTE: the check for a trailing PLACE_HOLDER term at offset 3 is currently disabled.
}

// SetCaseSensitive(true, false): the original casing of each term is kept.
TEST_F(StandardAnalyzerTest, test2) {
    StandardAnalyzer analyzer;
    TermList term_list;
    String input("Boost unit tests.");
    analyzer.SetCaseSensitive(true, false);
    analyzer.Analyze(input, term_list);

    struct ExpectedTerm {
        const char *text;
        unsigned offset;
    };
    const ExpectedTerm expected[] = {{"Boost", 0U}, {"unit", 1U}, {"tests", 2U}};
    const unsigned expected_count = sizeof(expected) / sizeof(expected[0]);

    ASSERT_EQ(term_list.size(), expected_count);
    for (unsigned idx = 0; idx < expected_count; ++idx) {
        ASSERT_EQ(term_list[idx].text_, String(expected[idx].text));
        ASSERT_EQ(term_list[idx].word_offset_, expected[idx].offset);
    }
    // NOTE: the check for a trailing PLACE_HOLDER term at offset 3 is currently disabled.
}

// SetExtractEngStem(true): the stem "test" is emitted after "tests" at the same offset.
TEST_F(StandardAnalyzerTest, test3) {
    StandardAnalyzer analyzer;
    TermList term_list;
    String input("Boost unit tests.");
    analyzer.SetExtractEngStem(true);
    analyzer.Analyze(input, term_list);

    struct ExpectedTerm {
        const char *text;
        unsigned offset;
    };
    const ExpectedTerm expected[] = {{"boost", 0U}, {"unit", 1U}, {"tests", 2U}, {"test", 2U}};
    const unsigned expected_count = sizeof(expected) / sizeof(expected[0]);

    ASSERT_EQ(term_list.size(), expected_count);
    for (unsigned idx = 0; idx < expected_count; ++idx) {
        ASSERT_EQ(term_list[idx].text_, String(expected[idx].text));
        ASSERT_EQ(term_list[idx].word_offset_, expected[idx].offset);
    }
    // NOTE: the check for a trailing PLACE_HOLDER term is currently disabled.
}

// SetCaseSensitive(true, true): "Boost" is followed by its lower-cased form at the
// same offset; the already lower-case words appear once.
TEST_F(StandardAnalyzerTest, test4) {
    StandardAnalyzer analyzer;
    TermList term_list;
    String input("Boost unit tests.");
    analyzer.SetCaseSensitive(true, true);
    analyzer.Analyze(input, term_list);

    struct ExpectedTerm {
        const char *text;
        unsigned offset;
    };
    const ExpectedTerm expected[] = {{"Boost", 0U}, {"boost", 0U}, {"unit", 1U}, {"tests", 2U}};
    const unsigned expected_count = sizeof(expected) / sizeof(expected[0]);

    ASSERT_EQ(term_list.size(), expected_count);
    for (unsigned idx = 0; idx < expected_count; ++idx) {
        ASSERT_EQ(term_list[idx].text_, String(expected[idx].text));
        ASSERT_EQ(term_list[idx].word_offset_, expected[idx].offset);
    }
    // NOTE: the check for a trailing PLACE_HOLDER term is currently disabled.
}

// A long mixed-case word stays a single term and is lower-cased as a whole.
TEST_F(StandardAnalyzerTest, test5) {
    StandardAnalyzer analyzer;
    TermList term_list;
    String input("BoostBoostboostBoostboost unit tests.");
    analyzer.Analyze(input, term_list);

    struct ExpectedTerm {
        const char *text;
        unsigned offset;
    };
    const ExpectedTerm expected[] = {{"boostboostboostboostboost", 0U}, {"unit", 1U}, {"tests", 2U}};
    const unsigned expected_count = sizeof(expected) / sizeof(expected[0]);

    ASSERT_EQ(term_list.size(), expected_count);
    for (unsigned idx = 0; idx < expected_count; ++idx) {
        ASSERT_EQ(term_list[idx].text_, String(expected[idx].text));
        ASSERT_EQ(term_list[idx].word_offset_, expected[idx].offset);
    }
    // NOTE: the check for a trailing PLACE_HOLDER term at offset 3 is currently disabled.
}

0 comments on commit 6f49f89

Please sign in to comment.