diff --git a/extension/fts/CMakeLists.txt b/extension/fts/CMakeLists.txt index 75fa2ebb11c..6a2343d02a4 100644 --- a/extension/fts/CMakeLists.txt +++ b/extension/fts/CMakeLists.txt @@ -14,7 +14,8 @@ add_subdirectory(third_party/snowball) target_link_libraries(fts_extension PRIVATE - snowball) + snowball + re2) set_extension_properties(fts_extension fts fts) diff --git a/extension/fts/src/fts_extension.cpp b/extension/fts/src/fts_extension.cpp index 7b75aa078e2..e10bda1607e 100644 --- a/extension/fts/src/fts_extension.cpp +++ b/extension/fts/src/fts_extension.cpp @@ -3,6 +3,7 @@ #include "catalog/catalog_entry/catalog_entry_type.h" #include "function/create_fts_index.h" #include "function/drop_fts_index.h" +#include "function/get_keys.h" #include "function/query_fts_gds.h" #include "function/query_fts_index.h" #include "function/stem.h" @@ -15,6 +16,7 @@ namespace fts_extension { void FTSExtension::load(main::ClientContext* context) { auto& db = *context->getDatabase(); ADD_SCALAR_FUNC(StemFunction); + ADD_SCALAR_FUNC(GetKeysFunction); ADD_GDS_FUNC(QFTSFunction); db.addStandaloneCallFunction(CreateFTSFunction::name, CreateFTSFunction::getFunctionSet()); db.addTableFunction(QueryFTSFunction::name, QueryFTSFunction::getFunctionSet()); diff --git a/extension/fts/src/function/CMakeLists.txt b/extension/fts/src/function/CMakeLists.txt index f6bb9c9c9f6..93f5d2f5327 100644 --- a/extension/fts/src/function/CMakeLists.txt +++ b/extension/fts/src/function/CMakeLists.txt @@ -6,7 +6,8 @@ add_library(kuzu_fts_function query_fts_index.cpp drop_fts_index.cpp query_fts_gds.cpp - fts_utils.cpp) + fts_utils.cpp + get_keys.cpp) set(FTS_OBJECT_FILES ${FTS_OBJECT_FILES} $ diff --git a/extension/fts/src/function/get_keys.cpp b/extension/fts/src/function/get_keys.cpp new file mode 100644 index 00000000000..bc9370b3532 --- /dev/null +++ b/extension/fts/src/function/get_keys.cpp @@ -0,0 +1,60 @@ +#include "function/get_keys.h" + +#include "common/exception/binder.h" +#include "common/string_utils.h" +#include "function/scalar_function.h" +#include "function/stem.h" +#include "function/string/functions/lower_function.h" +#include "libstemmer.h" +#include "re2.h" + +namespace kuzu { +namespace fts_extension { + +using namespace function; +using namespace common; + +struct GetKeys { + static void operation(common::ku_string_t& word, common::ku_string_t& stemmer, + common::list_entry_t& result, common::ValueVector& resultVector); +}; + +void GetKeys::operation(common::ku_string_t& query, common::ku_string_t& stemmer, + common::list_entry_t& result, common::ValueVector& resultVector) { + std::string regexPattern = "[0-9!@#$%^&*()_+={}\\[\\]:;<>,.?~\\/\\|'\"`-]+"; + std::string replacePattern = " "; + std::string resultStr = query.getAsString(); + RE2::GlobalReplace(&resultStr, regexPattern, replacePattern); + StringUtils::toLower(resultStr); + auto words = StringUtils::split(resultStr, " "); + StemFunction::validateStemmer(stemmer.getAsString()); + auto sbStemmer = sb_stemmer_new(reinterpret_cast(stemmer.getData()), "UTF_8"); + result = ListVector::addList(&resultVector, words.size()); + for (auto i = 0u; i < result.size; i++) { + auto& word = words[i]; + auto stemData = sb_stemmer_stem(sbStemmer, reinterpret_cast(word.c_str()), + word.length()); + ListVector::getDataVector(&resultVector) + ->setValue(result.offset + i, + std::string_view{reinterpret_cast(stemData)}); + } + sb_stemmer_delete(sbStemmer); +} + +static std::unique_ptr bindFunc(ScalarBindFuncInput input) { + return FunctionBindData::getSimpleBindData(input.arguments, + LogicalType::LIST(LogicalType::STRING())); +} + +function::function_set GetKeysFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::LIST, + ScalarFunction::BinaryStringExecFunction, + bindFunc)); + return result; +} + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/function/query_fts_index.cpp b/extension/fts/src/function/query_fts_index.cpp index 03922e86c9d..4af6050d2e1 100644 --- a/extension/fts/src/function/query_fts_index.cpp +++ b/extension/fts/src/function/query_fts_index.cpp @@ -115,13 +115,11 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output) runQuery(clientContext, query); // Compute score query = common::stringFormat( - "UNWIND tokenize('{}') AS tk " - "WITH collect(stem(tk, '{}')) AS keywords " "MATCH (a:`{}`) " - "WHERE list_contains(keywords, a.term) " + "WHERE list_contains(get_keys('{}', '{}'), a.term) " "CALL QFTS(PK, a, {}, {}, cast({} as UINT64), {}, cast({} as UINT64), {}, '{}') " "RETURN _node AS p, score", - actualQuery, bindData.entry.getFTSConfig().stemmer, bindData.getTermsTableName(), + bindData.getTermsTableName(), actualQuery, bindData.entry.getFTSConfig().stemmer, bindData.config.k, bindData.config.b, numDocs, avgDocLen, numTermsInQuery, bindData.config.isConjunctive ? "true" : "false", bindData.tableName); localState->result = runQuery(clientContext, query); @@ -129,13 +127,13 @@ static common::offset_t tableFunc(TableFuncInput& data, TableFuncOutput& output) query = stringFormat("CALL drop_project_graph('PK')"); runQuery(clientContext, query); } - if (localState->numRowsOutput >= localState->result->getNumTuples()) { + if (localState->numRowsOutput >= localState->result->getTable()->getNumTuples()) { return 0; } auto resultTable = localState->result->getTable(); resultTable->scan(output.vectors, localState->numRowsOutput, 1 /* numRowsToScan */); localState->numRowsOutput++; - return 1; + return output.dataChunk.state->getSelSize(); } std::unique_ptr initLocalState( diff --git a/extension/fts/src/include/function/get_keys.h b/extension/fts/src/include/function/get_keys.h new file mode 100644 index 00000000000..a13ada19ebf --- /dev/null +++ b/extension/fts/src/include/function/get_keys.h @@ -0,0 +1,15 @@ +#pragma once + +#include "function/function.h" + +namespace kuzu { +namespace fts_extension { + +struct GetKeysFunction { + static constexpr const char* name = "GET_KEYS"; + + static function::function_set getFunctionSet(); +}; + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/test/test_files/get_keys.test b/extension/fts/test/test_files/get_keys.test new file mode 100644 index 00000000000..1f40d18855b --- /dev/null +++ b/extension/fts/test/test_files/get_keys.test @@ -0,0 +1,16 @@ +-DATASET CSV empty + +-- + +-CASE get_keys_function +-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension" +---- ok +-STATEMENT RETURN get_keys('alice bob carol', 'english') +---- 1 +[alic,bob,carol] +-STATEMENT RETURN get_keys('alice bo\'b dan', 'english') +---- 1 +[alic,bo,b,dan] +-STATEMENT RETURN get_keys('Uwaterloo ut', 'english') +---- 1 +[uwaterloo,ut] diff --git a/src/planner/plan/plan_read.cpp b/src/planner/plan/plan_read.cpp index 6c83e4333f6..aadd260e44a 100644 --- a/src/planner/plan/plan_read.cpp +++ b/src/planner/plan/plan_read.cpp @@ -146,7 +146,7 @@ void Planner::planGDSCall(const BoundReadingClause& readingClause, gdsCall->computeFactorizedSchema(); probePlan.setLastOperator(gdsCall); if (gdsCall->constPtrCast()->getInfo().func.name == "QFTS") { - auto op = plan->getLastOperator()->getChild(0)->getChild(0)->getChild(1); + auto op = plan->getLastOperator()->getChild(0); auto prop = bindData->getNodeInput()->constCast().getPropertyExpression( "df");