Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement get_keys function in fts #4647

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion extension/fts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ add_subdirectory(third_party/snowball)

target_link_libraries(fts_extension
PRIVATE
snowball)
snowball
re2)

set_extension_properties(fts_extension fts fts)

Expand Down
2 changes: 2 additions & 0 deletions extension/fts/src/fts_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "catalog/catalog_entry/catalog_entry_type.h"
#include "function/create_fts_index.h"
#include "function/drop_fts_index.h"
#include "function/get_keys.h"
#include "function/query_fts_gds.h"
#include "function/query_fts_index.h"
#include "function/stem.h"
Expand All @@ -15,6 +16,7 @@ namespace fts_extension {
void FTSExtension::load(main::ClientContext* context) {
auto& db = *context->getDatabase();
ADD_SCALAR_FUNC(StemFunction);
ADD_SCALAR_FUNC(GetKeysFunction);
ADD_GDS_FUNC(QFTSFunction);
db.addStandaloneCallFunction(CreateFTSFunction::name, CreateFTSFunction::getFunctionSet());
db.addTableFunction(QueryFTSFunction::name, QueryFTSFunction::getFunctionSet());
Expand Down
3 changes: 2 additions & 1 deletion extension/fts/src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ add_library(kuzu_fts_function
query_fts_index.cpp
drop_fts_index.cpp
query_fts_gds.cpp
fts_utils.cpp)
fts_utils.cpp
get_keys.cpp)

set(FTS_OBJECT_FILES
${FTS_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_fts_function>
Expand Down
60 changes: 60 additions & 0 deletions extension/fts/src/function/get_keys.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "function/get_keys.h"

#include "common/exception/binder.h"
#include "common/string_utils.h"
#include "function/scalar_function.h"
#include "function/stem.h"
#include "function/string/functions/lower_function.h"
#include "libstemmer.h"
#include "re2.h"

namespace kuzu {
namespace fts_extension {

using namespace function;
using namespace common;

// Executor for the GET_KEYS scalar function. `operation` is the callback that
// ScalarFunction::BinaryStringExecFunction is instantiated with in
// GetKeysFunction::getFunctionSet().
struct GetKeys {
    // query:        text to split into keys.
    // stemmer:      name of the Snowball stemming algorithm to apply to each key.
    // result:       receives the list entry allocated inside resultVector.
    // resultVector: vector that stores the produced list of stemmed keys.
    static void operation(common::ku_string_t& query, common::ku_string_t& stemmer,
        common::list_entry_t& result, common::ValueVector& resultVector);
};

// Splits `query` into keys and stems each of them.
//
// Every run of digits/punctuation matched by the separator pattern is replaced by a single
// space, the text is lower-cased and split on spaces, and each resulting token is stemmed with
// the Snowball algorithm named by `stemmer`. The stems are written, in token order, into a list
// allocated in `resultVector`; `result` is set to that list's entry.
//
// `stemmer` is checked with StemFunction::validateStemmer before a stemmer is created.
void GetKeys::operation(common::ku_string_t& query, common::ku_string_t& stemmer,
    common::list_entry_t& result, common::ValueVector& resultVector) {
    // The pattern is constant, so compile it once instead of once per input row.
    static const RE2 separatorPattern("[0-9!@#$%^&*()_+={}\\[\\]:;<>,.?~\\/\\|'\"`-]+");
    const std::string replacePattern = " ";

    std::string resultStr = query.getAsString();
    RE2::GlobalReplace(&resultStr, separatorPattern, replacePattern);
    StringUtils::toLower(resultStr);
    const auto words = StringUtils::split(resultStr, " ");

    // ku_string_t is length-delimited, so its raw buffer is not guaranteed to end with a NUL
    // byte. libstemmer expects a C string for the algorithm name, hence the std::string copy.
    const std::string stemmerName = stemmer.getAsString();
    StemFunction::validateStemmer(stemmerName);

    // Own the stemmer through RAII so it is released even if writing a value below throws.
    // NOTE(review): sb_stemmer_new returns NULL for an unknown algorithm or on allocation
    // failure; this relies on validateStemmer rejecting every name libstemmer does not know.
    std::unique_ptr<sb_stemmer, decltype(&sb_stemmer_delete)> sbStemmer{
        sb_stemmer_new(stemmerName.c_str(), "UTF_8"), &sb_stemmer_delete};

    result = ListVector::addList(&resultVector, words.size());
    auto* dataVector = ListVector::getDataVector(&resultVector);
    for (auto i = 0u; i < result.size; i++) {
        const auto& word = words[i];
        const auto* stemData = sb_stemmer_stem(sbStemmer.get(),
            reinterpret_cast<const sb_symbol*>(word.c_str()), word.length());
        // Take the stem's length from libstemmer instead of scanning for a terminator.
        const auto stemLength = static_cast<std::size_t>(sb_stemmer_length(sbStemmer.get()));
        dataVector->setValue(result.offset + i,
            std::string_view{reinterpret_cast<const char*>(stemData), stemLength});
    }
}

// Bind step for GET_KEYS: whatever the bound arguments are, the function's result type is
// LIST(STRING).
static std::unique_ptr<FunctionBindData> bindFunc(ScalarBindFuncInput input) {
    const auto& boundArguments = input.arguments;
    return FunctionBindData::getSimpleBindData(boundArguments,
        LogicalType::LIST(LogicalType::STRING()));
}

// Builds the function set for GET_KEYS. It holds a single overload,
// (STRING, STRING) -> LIST, executed by GetKeys::operation and bound by bindFunc.
function::function_set GetKeysFunction::getFunctionSet() {
    const std::vector<LogicalTypeID> parameterTypes{LogicalTypeID::STRING,
        LogicalTypeID::STRING};
    function_set functions;
    functions.emplace_back(std::make_unique<ScalarFunction>(name, parameterTypes,
        LogicalTypeID::LIST,
        ScalarFunction::BinaryStringExecFunction<ku_string_t, ku_string_t, list_entry_t, GetKeys>,
        bindFunc));
    return functions;
}

} // namespace fts_extension
} // namespace kuzu
10 changes: 4 additions & 6 deletions extension/fts/src/function/query_fts_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,27 +115,25 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output)
runQuery(clientContext, query);
// Compute score
query = common::stringFormat(
"UNWIND tokenize('{}') AS tk "
"WITH collect(stem(tk, '{}')) AS keywords "
"MATCH (a:`{}`) "
"WHERE list_contains(keywords, a.term) "
"WHERE list_contains(get_keys('{}', '{}'), a.term) "
"CALL QFTS(PK, a, {}, {}, cast({} as UINT64), {}, cast({} as UINT64), {}, '{}') "
"RETURN _node AS p, score",
actualQuery, bindData.entry.getFTSConfig().stemmer, bindData.getTermsTableName(),
bindData.getTermsTableName(), actualQuery, bindData.entry.getFTSConfig().stemmer,
bindData.config.k, bindData.config.b, numDocs, avgDocLen, numTermsInQuery,
bindData.config.isConjunctive ? "true" : "false", bindData.tableName);
localState->result = runQuery(clientContext, query);
// Remove project graph
query = stringFormat("CALL drop_project_graph('PK')");
runQuery(clientContext, query);
}
if (localState->numRowsOutput >= localState->result->getNumTuples()) {
if (localState->numRowsOutput >= localState->result->getTable()->getNumTuples()) {
return 0;
}
auto resultTable = localState->result->getTable();
resultTable->scan(output.vectors, localState->numRowsOutput, 1 /* numRowsToScan */);
localState->numRowsOutput++;
return 1;
return output.dataChunk.state->getSelSize();
}

std::unique_ptr<TableFuncLocalState> initLocalState(
Expand Down
15 changes: 15 additions & 0 deletions extension/fts/src/include/function/get_keys.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include "function/function.h"

namespace kuzu {
namespace fts_extension {

// Registration descriptor for the GET_KEYS function of the FTS extension.
struct GetKeysFunction {
// Name under which the function is registered.
static constexpr const char* name = "GET_KEYS";

// Returns the set of overloads registered under `name`.
static function::function_set getFunctionSet();
};

} // namespace fts_extension
} // namespace kuzu
16 changes: 16 additions & 0 deletions extension/fts/test/test_files/get_keys.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-DATASET CSV empty

--

-CASE get_keys_function
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT RETURN get_keys('alice bob carol', 'english')
---- 1
[alic,bob,carol]
-STATEMENT RETURN get_keys('alice bo\'b dan', 'english')
---- 1
[alic,bo,b,dan]
-STATEMENT RETURN get_keys('Uwaterloo ut', 'english')
---- 1
[uwaterloo,ut]
2 changes: 1 addition & 1 deletion src/planner/plan/plan_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
gdsCall->computeFactorizedSchema();
probePlan.setLastOperator(gdsCall);
if (gdsCall->constPtrCast<LogicalGDSCall>()->getInfo().func.name == "QFTS") {
auto op = plan->getLastOperator()->getChild(0)->getChild(0)->getChild(1);
auto op = plan->getLastOperator()->getChild(0);

Check warning on line 149 in src/planner/plan/plan_read.cpp

View check run for this annotation

Codecov / codecov/patch

src/planner/plan/plan_read.cpp#L149

Added line #L149 was not covered by tests
auto prop =
bindData->getNodeInput()->constCast<NodeExpression>().getPropertyExpression(
"df");
Expand Down
Loading