Skip to content

Commit

Permalink
Implement get_keys function in fts (#4647)
Browse files Browse the repository at this point in the history
* Implement get-keys function

---------

Co-authored-by: CI Bot <[email protected]>
  • Loading branch information
acquamarin and acquamarin authored Dec 18, 2024
1 parent 109012e commit c46eec1
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 9 deletions.
3 changes: 2 additions & 1 deletion extension/fts/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ add_subdirectory(third_party/snowball)

target_link_libraries(fts_extension
PRIVATE
snowball)
snowball
re2)

set_extension_properties(fts_extension fts fts)

Expand Down
2 changes: 2 additions & 0 deletions extension/fts/src/fts_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "catalog/catalog_entry/catalog_entry_type.h"
#include "function/create_fts_index.h"
#include "function/drop_fts_index.h"
#include "function/get_keys.h"
#include "function/query_fts_gds.h"
#include "function/query_fts_index.h"
#include "function/stem.h"
Expand All @@ -15,6 +16,7 @@ namespace fts_extension {
void FTSExtension::load(main::ClientContext* context) {
auto& db = *context->getDatabase();
ADD_SCALAR_FUNC(StemFunction);
ADD_SCALAR_FUNC(GetKeysFunction);
ADD_GDS_FUNC(QFTSFunction);
db.addStandaloneCallFunction(CreateFTSFunction::name, CreateFTSFunction::getFunctionSet());
db.addTableFunction(QueryFTSFunction::name, QueryFTSFunction::getFunctionSet());
Expand Down
3 changes: 2 additions & 1 deletion extension/fts/src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ add_library(kuzu_fts_function
query_fts_index.cpp
drop_fts_index.cpp
query_fts_gds.cpp
fts_utils.cpp)
fts_utils.cpp
get_keys.cpp)

set(FTS_OBJECT_FILES
${FTS_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_fts_function>
Expand Down
60 changes: 60 additions & 0 deletions extension/fts/src/function/get_keys.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "function/get_keys.h"

#include "common/exception/binder.h"
#include "common/string_utils.h"
#include "function/scalar_function.h"
#include "function/stem.h"
#include "function/string/functions/lower_function.h"
#include "libstemmer.h"
#include "re2.h"

namespace kuzu {
namespace fts_extension {

using namespace function;
using namespace common;

struct GetKeys {
static void operation(common::ku_string_t& word, common::ku_string_t& stemmer,
common::list_entry_t& result, common::ValueVector& resultVector);
};

void GetKeys::operation(common::ku_string_t& query, common::ku_string_t& stemmer,
common::list_entry_t& result, common::ValueVector& resultVector) {
std::string regexPattern = "[0-9!@#$%^&*()_+={}\\[\\]:;<>,.?~\\/\\|'\"`-]+";
std::string replacePattern = " ";
std::string resultStr = query.getAsString();
RE2::GlobalReplace(&resultStr, regexPattern, replacePattern);
StringUtils::toLower(resultStr);
auto words = StringUtils::split(resultStr, " ");
StemFunction::validateStemmer(stemmer.getAsString());
auto sbStemmer = sb_stemmer_new(reinterpret_cast<const char*>(stemmer.getData()), "UTF_8");
result = ListVector::addList(&resultVector, words.size());
for (auto i = 0u; i < result.size; i++) {
auto& word = words[i];
auto stemData = sb_stemmer_stem(sbStemmer, reinterpret_cast<const sb_symbol*>(word.c_str()),
word.length());
ListVector::getDataVector(&resultVector)
->setValue(result.offset + i,
std::string_view{reinterpret_cast<const char*>(stemData)});
}
sb_stemmer_delete(sbStemmer);
}

static std::unique_ptr<FunctionBindData> bindFunc(ScalarBindFuncInput input) {
return FunctionBindData::getSimpleBindData(input.arguments,
LogicalType::LIST(LogicalType::STRING()));
}

function::function_set GetKeysFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING},
LogicalTypeID::LIST,
ScalarFunction::BinaryStringExecFunction<ku_string_t, ku_string_t, list_entry_t, GetKeys>,
bindFunc));
return result;
}

} // namespace fts_extension
} // namespace kuzu
10 changes: 4 additions & 6 deletions extension/fts/src/function/query_fts_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,27 +115,25 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output)
runQuery(clientContext, query);
// Compute score
query = common::stringFormat(
"UNWIND tokenize('{}') AS tk "
"WITH collect(stem(tk, '{}')) AS keywords "
"MATCH (a:`{}`) "
"WHERE list_contains(keywords, a.term) "
"WHERE list_contains(get_keys('{}', '{}'), a.term) "
"CALL QFTS(PK, a, {}, {}, cast({} as UINT64), {}, cast({} as UINT64), {}, '{}') "
"RETURN _node AS p, score",
actualQuery, bindData.entry.getFTSConfig().stemmer, bindData.getTermsTableName(),
bindData.getTermsTableName(), actualQuery, bindData.entry.getFTSConfig().stemmer,
bindData.config.k, bindData.config.b, numDocs, avgDocLen, numTermsInQuery,
bindData.config.isConjunctive ? "true" : "false", bindData.tableName);
localState->result = runQuery(clientContext, query);
// Remove project graph
query = stringFormat("CALL drop_project_graph('PK')");
runQuery(clientContext, query);
}
if (localState->numRowsOutput >= localState->result->getNumTuples()) {
if (localState->numRowsOutput >= localState->result->getTable()->getNumTuples()) {
return 0;
}
auto resultTable = localState->result->getTable();
resultTable->scan(output.vectors, localState->numRowsOutput, 1 /* numRowsToScan */);
localState->numRowsOutput++;
return 1;
return output.dataChunk.state->getSelSize();
}

std::unique_ptr<TableFuncLocalState> initLocalState(
Expand Down
15 changes: 15 additions & 0 deletions extension/fts/src/include/function/get_keys.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include "function/function.h"

namespace kuzu {
namespace fts_extension {

struct GetKeysFunction {
static constexpr const char* name = "GET_KEYS";

static function::function_set getFunctionSet();
};

} // namespace fts_extension
} // namespace kuzu
16 changes: 16 additions & 0 deletions extension/fts/test/test_files/get_keys.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-DATASET CSV empty

--

-CASE get_keys_function
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT RETURN get_keys('alice bob carol', 'english')
---- 1
[alic,bob,carol]
-STATEMENT RETURN get_keys('alice bo\'b dan', 'english')
---- 1
[alic,bo,b,dan]
-STATEMENT RETURN get_keys('Uwaterloo ut', 'english')
---- 1
[uwaterloo,ut]
2 changes: 1 addition & 1 deletion src/planner/plan/plan_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void Planner::planGDSCall(const BoundReadingClause& readingClause,
gdsCall->computeFactorizedSchema();
probePlan.setLastOperator(gdsCall);
if (gdsCall->constPtrCast<LogicalGDSCall>()->getInfo().func.name == "QFTS") {
auto op = plan->getLastOperator()->getChild(0)->getChild(0)->getChild(1);
auto op = plan->getLastOperator()->getChild(0);
auto prop =
bindData->getNodeInput()->constCast<NodeExpression>().getPropertyExpression(
"df");
Expand Down

0 comments on commit c46eec1

Please sign in to comment.