Skip to content

Commit

Permalink
support path semantic in non-recursive-path
Browse files Browse the repository at this point in the history
  • Loading branch information
wyj committed Dec 18, 2024
1 parent 970f5c4 commit e163d9a
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 13 deletions.
9 changes: 7 additions & 2 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ QueryGraph Binder::bindPatternElement(const PatternElement& patternElement) {
nodeAndRels.push_back(rightNode);
leftNode = rightNode;
}
if (patternElement.hasPathName()) {
auto pathName = patternElement.getPathName();

if (patternElement.hasPathName() ||
this->clientContext->getClientConfig()->recursivePatternSemantic != PathSemantic::WALK) {
// query may not explicitly define a path name, but it requires a path for semantic filter;
// therefore, an implicitly defined path is added internally.
auto pathName =
patternElement.hasPathName() ? patternElement.getPathName() : getInternalPathName();
auto pathExpression = createPath(pathName, nodeAndRels);
addToScope(pathName, pathExpression);
}
Expand Down
14 changes: 14 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ std::string Binder::getUniqueExpressionName(const std::string& name) {
return "_" + std::to_string(lastExpressionId++) + "_" + name;
}

// Generates the name used for an implicitly created path variable: the
// InternalKeyword::INTERNAL_PATH prefix followed by the current value of
// lastInternalPathId. The counter is post-incremented, so the next call on the
// same binder yields a different suffix.
std::string Binder::getInternalPathName() {
    std::string internalName(InternalKeyword::INTERNAL_PATH);
    internalName += std::to_string(lastInternalPathId++);
    return internalName;
}

struct ReservedNames {
// Column name that might conflict with internal names.
static std::unordered_set<std::string> getColumnNames() {
Expand Down Expand Up @@ -267,5 +271,15 @@ function::TableFunction Binder::getScanFunction(FileTypeInfo typeInfo, const Rea
return *func->ptrCast<function::TableFunction>();
}

// Collects every expression returned by scope.getExpressions() whose
// expressionType is ExpressionType::PATH, in the order the scope reports them.
// Returns an empty vector when the scope holds no path expression.
expression_vector Binder::findPathExpressionInScope() {
    expression_vector pathExpressions;
    // Bind each entry by const reference: the loop only inspects the element and
    // copies it solely when it is kept, so a by-value loop variable would pay for
    // an extra copy of every scope entry (presumably a shared_ptr refcount bump).
    for (const auto& expr : scope.getExpressions()) {
        if (expr->expressionType == ExpressionType::PATH) {
            pathExpressions.push_back(expr);
        }
    }
    return pathExpressions;
}

} // namespace binder
} // namespace kuzu
20 changes: 18 additions & 2 deletions src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/visitor/property_collector.h"

#include "binder/expression/expression_util.h"
#include "binder/expression/property_expression.h"
#include "binder/expression_visitor.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
Expand Down Expand Up @@ -51,8 +52,23 @@ void PropertyCollector::visitQueryPartSkipNodeRel(const NormalizedQueryPart& que

void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) {
auto& matchClause = readingClause.constCast<BoundMatchClause>();
if (matchClause.hasPredicate()) {
collectProperties(matchClause.getPredicate());
if (recursivePatternSemantic == PathSemantic::WALK) {
if (matchClause.hasPredicate()) {
collectProperties(matchClause.getPredicate());
}
} else {
for (auto node : matchClause.getQueryGraphCollection()->getQueryNodes()) {
for (auto prop : node->getPropertyExprs()) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(node->getPropertyExpression(name));
}
}
for (auto rel : matchClause.getQueryGraphCollection()->getQueryRels()) {
for (auto prop : rel->getPropertyExprs()) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(rel->getPropertyExpression(name));
}
}
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class Binder {

public:
explicit Binder(main::ClientContext* clientContext)
: lastExpressionId{0}, scope{}, expressionBinder{this, clientContext},
clientContext{clientContext} {}
: lastInternalPathId{0}, lastExpressionId{0}, scope{},
expressionBinder{this, clientContext}, clientContext{clientContext} {}

std::unique_ptr<BoundStatement> bind(const parser::Statement& statement);

Expand Down Expand Up @@ -291,6 +291,7 @@ class Binder {
void validateDropSequence(const parser::Statement& dropTable);
/*** helpers ***/
std::string getUniqueExpressionName(const std::string& name);
std::string getInternalPathName();

static bool reservedInColumnName(const std::string& name);
static bool reservedInPropertyLookup(const std::string& name);
Expand All @@ -305,7 +306,10 @@ class Binder {

ExpressionBinder* getExpressionBinder() { return &expressionBinder; }

expression_vector findPathExpressionInScope();

private:
uint32_t lastInternalPathId;
common::idx_t lastExpressionId;
BinderScope scope;
ExpressionBinder expressionBinder;
Expand Down
9 changes: 9 additions & 0 deletions src/include/binder/visitor/property_collector.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include "binder/bound_statement_visitor.h"
#include "common/enums/path_semantic.h"
#include "main/client_config.h"

namespace kuzu {
namespace binder {
Expand All @@ -14,6 +16,11 @@ class PropertyCollector final : public BoundStatementVisitor {
// See with_clause_projection_rewriter for more details.
void visitSingleQuerySkipNodeRel(const NormalizedSingleQuery& singleQuery);

// Stores the path semantic this collector should apply; visitMatch branches on
// the stored value when deciding which properties to collect.
void setRecursivePatternSemantic(common::PathSemantic semantic) {
    this->recursivePatternSemantic = semantic;
}
// Returns the path semantic currently stored on this collector. Declared const
// because it only reads the member, so it is also callable on a const collector.
common::PathSemantic getRecursivePatternSemantic() const { return recursivePatternSemantic; }

private:
void visitQueryPartSkipNodeRel(const NormalizedQueryPart& queryPart);

Expand All @@ -36,6 +43,8 @@ class PropertyCollector final : public BoundStatementVisitor {

private:
expression_set properties;
common::PathSemantic recursivePatternSemantic =
main::ClientConfigDefault::RECURSIVE_PATTERN_SEMANTIC;
};

} // namespace binder
Expand Down
1 change: 1 addition & 0 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct InternalKeyword {
static constexpr char PLACE_HOLDER[] = "_PLACE_HOLDER";
static constexpr char MAP_KEY[] = "KEY";
static constexpr char MAP_VALUE[] = "VALUE";
static constexpr char INTERNAL_PATH[] = "_INTERNAL_PATH";

static constexpr std::string_view ROW_OFFSET = "_row_offset";
static constexpr std::string_view SRC_OFFSET = "_src_offset";
Expand Down
4 changes: 4 additions & 0 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class KUZU_API ClientContext {
common::VirtualFileSystem* getVFSUnsafe() const;
common::RandomEngine* getRandomEngine();

// binder
binder::Binder* getBinder() const;

// Query.
std::unique_ptr<PreparedStatement> prepare(std::string_view query);
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand Down Expand Up @@ -208,6 +211,7 @@ class KUZU_API ClientContext {
// Graph entries
std::unique_ptr<graph::GraphEntrySet> graphEntrySet;
std::mutex mtx;
std::unique_ptr<binder::Binder> binder;
};

} // namespace main
Expand Down
40 changes: 40 additions & 0 deletions src/include/optimizer/path_semantic_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#pragma once

#include "binder/binder.h"
#include "binder/expression/expression.h"
#include "logical_operator_visitor.h"
#include "planner/operator/logical_plan.h"
namespace kuzu {
namespace main {
class ClientContext;
}
namespace catalog {
class Catalog;
}
namespace binder {}
namespace optimizer {

class PathSemanticRewriter : public LogicalOperatorVisitor {
public:
explicit PathSemanticRewriter(main::ClientContext* context)
: hasReplace(false), context{context} {}
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(const std::shared_ptr<planner::LogicalOperator>& op,
const std::shared_ptr<planner::LogicalOperator>& parent, int index);
std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) override;
std::shared_ptr<planner::LogicalOperator> visitCrossProductReplace(
std::shared_ptr<planner::LogicalOperator> op);

private:
bool hasReplace;
main::ClientContext* context;
binder::expression_vector scanExpression;
std::shared_ptr<planner::LogicalOperator> appendPathSemanticFilter(
const std::shared_ptr<planner::LogicalOperator> op);
};

} // namespace optimizer
} // namespace kuzu
11 changes: 7 additions & 4 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
}
}
// binding
auto binder = Binder(this);
binder = std::make_unique<binder::Binder>((this));
if (inputParams) {
binder.setInputParameters(*inputParams);
binder->setInputParameters(*inputParams);
}
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement->parameterMap = binder.getParameterMap();
auto boundStatement = binder->bind(*parsedStatement);
preparedStatement->parameterMap = binder->getParameterMap();
preparedStatement->statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
// planning
Expand Down Expand Up @@ -581,6 +581,9 @@ processor::WarningContext& ClientContext::getWarningContextUnsafe() {
const processor::WarningContext& ClientContext::getWarningContext() const {
return warningContext;
}
// Returns a non-owning pointer to the Binder held in this context's `binder`
// unique_ptr. Ownership stays with the ClientContext, so callers must not delete
// the result. prepareNoLock assigns a fresh Binder to that member, so a pointer
// obtained earlier stops being valid after the next prepare.
// NOTE(review): presumably null until the first statement has been prepared —
// confirm that every caller runs after binding.
binder::Binder* ClientContext::getBinder() const {
return binder.get();
}

graph::GraphEntrySet& ClientContext::getGraphEntrySetUnsafe() {
return *graphEntrySet;
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_library(kuzu_optimizer
remove_factorization_rewriter.cpp
remove_unnecessary_join_optimizer.cpp
top_k_optimizer.cpp
path_semantic_rewriter.cpp
limit_push_down_optimizer.cpp)

set(ALL_OBJECT_FILES
Expand Down
6 changes: 6 additions & 0 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "optimizer/factorization_rewriter.h"
#include "optimizer/filter_push_down_optimizer.h"
#include "optimizer/limit_push_down_optimizer.h"
#include "optimizer/path_semantic_rewriter.h"
#include "optimizer/projection_push_down_optimizer.h"
#include "optimizer/remove_factorization_rewriter.h"
#include "optimizer/remove_unnecessary_join_optimizer.h"
Expand All @@ -29,6 +30,11 @@ void Optimizer::optimize(planner::LogicalPlan* plan, main::ClientContext* contex
auto correlatedSubqueryUnnestSolver = CorrelatedSubqueryUnnestSolver(nullptr);
correlatedSubqueryUnnestSolver.solve(plan->getLastOperator().get());

auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer();
removeUnnecessaryJoinOptimizer.rewrite(plan);
// Attach path-semantic (trail / acyclic) filters to the join operators in the plan.
auto pathSemanticRewriter = PathSemanticRewriter(context);
pathSemanticRewriter.rewrite(plan);

// Run the join-removal pass again on the rewritten plan. The optimizer object
// declared above is reused: declaring a second `removeUnnecessaryJoinOptimizer`
// in this scope is a redefinition and does not compile.
// NOTE(review): confirm that two passes are intended; if only one is wanted,
// drop either this call or the pair at the top of this block.
removeUnnecessaryJoinOptimizer.rewrite(plan);

Expand Down
123 changes: 123 additions & 0 deletions src/optimizer/path_semantic_rewriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "optimizer/path_semantic_rewriter.h"

#include "binder/expression/path_expression.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_visitor.h"
#include "catalog/catalog.h"
#include "common/exception/internal.h"
#include "function/built_in_function_utils.h"
#include "function/path/vector_path_functions.h"
#include "function/scalar_function.h"
#include "main/client_context.h"
#include "planner/operator/logical_filter.h"
using namespace kuzu::common;
using namespace kuzu::planner;

namespace kuzu {
namespace optimizer {

// Entry point of the rewriter: starts the traversal at the plan's last operator.
// That operator has no parent, so it is visited with a null parent and index 0.
void PathSemanticRewriter::rewrite(planner::LogicalPlan* plan) {
    const std::shared_ptr<planner::LogicalOperator> planRoot = plan->getLastOperator();
    visitOperator(planRoot, nullptr, 0);
}

// Visits `op`, which occupies child slot `index` of `parent` (`parent` is null for
// the plan root). HASH_JOIN and CROSS_PRODUCT operators are handed to their
// *Replace hooks; every other operator type is left as is.
//
// Once `hasReplace` is set and a parent exists, the (possibly wrapped) operator is
// written back into the parent's slot and the traversal does not descend below it.
// Otherwise the children of `op` are visited recursively.
//
// NOTE(review): `hasReplace` is never reset, so after the first filter is created
// every later visit with a non-null parent takes the write-back branch and its
// subtree is not explored. Also, when the root itself is a join (parent == null)
// the wrapped operator is not installed anywhere. Confirm both are intended.
void PathSemanticRewriter::visitOperator(const std::shared_ptr<planner::LogicalOperator>& op,
    const std::shared_ptr<planner::LogicalOperator>& parent, int index) {
    std::shared_ptr<planner::LogicalOperator> replacement = op;
    const auto operatorType = op->getOperatorType();
    if (operatorType == planner::LogicalOperatorType::HASH_JOIN) {
        replacement = visitHashJoinReplace(op);
    } else if (operatorType == planner::LogicalOperatorType::CROSS_PRODUCT) {
        replacement = visitCrossProductReplace(op);
    }

    if (hasReplace && parent != nullptr) {
        parent->setChild(index, replacement);
        return;
    }

    for (auto childIdx = 0u; childIdx < op->getNumChildren(); ++childIdx) {
        visitOperator(op->getChild(childIdx), op, childIdx);
    }
}

// Maps a path semantic to the name of the scalar function that checks it:
// TRAIL -> IsTrailFunction::name, ACYCLIC -> IsACyclicFunction::name.
// Any other semantic yields an empty string, which callers treat as
// "no filter required".
std::string semanticSwitch(const common::PathSemantic& semantic) {
    if (semantic == common::PathSemantic::TRAIL) {
        return function::IsTrailFunction::name;
    }
    if (semantic == common::PathSemantic::ACYCLIC) {
        return function::IsACyclicFunction::name;
    }
    return std::string();
}

// Hash-join hook: returns the join wrapped in the path-semantic filters, or the
// join itself when appendPathSemanticFilter adds none.
std::shared_ptr<LogicalOperator> PathSemanticRewriter::visitHashJoinReplace(
    std::shared_ptr<LogicalOperator> op) {
    auto filteredJoin = appendPathSemanticFilter(op);
    return filteredJoin;
}

// Stacks one LogicalFilter per path expression found in the binder's scope on top
// of `op`. Each filter's predicate is the scalar function selected by the client's
// recursivePatternSemantic setting (see semanticSwitch) applied to that path.
//
// Returns `op` unchanged when the configured semantic needs no check or when the
// scope holds no path expression; otherwise returns the outermost filter created.
// Sets `hasReplace` whenever at least one filter is created.
std::shared_ptr<LogicalOperator> PathSemanticRewriter::appendPathSemanticFilter(
    const std::shared_ptr<LogicalOperator> op) {
    // Decide first whether a filter is needed at all, so the no-filter case does
    // not touch the binder scope or the catalog.
    auto semanticFunctionName =
        semanticSwitch(context->getClientConfig()->recursivePatternSemantic);
    if (semanticFunctionName.empty()) {
        return op;
    }

    // get path expression from binder
    auto pathExpressions = context->getBinder()->findPathExpressionInScope();
    if (pathExpressions.empty()) {
        return op;
    }

    auto catalog = context->getCatalog();
    auto transaction = context->getTx();
    // The function lookup does not depend on the path being processed, so it is
    // fetched once instead of once per loop iteration.
    auto functions = catalog->getFunctions(transaction);

    auto resultOp = op;
    // append is_trail or is_acyclic function filter
    for (const auto& expr : pathExpressions) {
        std::vector<LogicalType> childrenTypes;
        childrenTypes.push_back(expr->getDataType().copy());

        auto bindExpr = binder::expression_vector{expr};
        // Work on a private copy of the matched function: it is moved into the
        // filter expression below.
        auto function = function::BuiltInFunctionsUtils::matchFunction(transaction,
            semanticFunctionName, childrenTypes, functions)
                            ->ptrCast<function::ScalarFunction>()
                            ->copy();

        std::unique_ptr<function::FunctionBindData> bindData;
        if (function.bindFunc) {
            bindData = function.bindFunc({bindExpr, &function, context});
        } else {
            // No custom binder: fall back to bind data carrying the declared return type.
            bindData =
                std::make_unique<function::FunctionBindData>(LogicalType(function.returnTypeID));
        }

        // The unique name must be computed before `function` is moved from.
        auto uniqueExpressionName =
            binder::ScalarFunctionExpression::getUniqueName(function.name, bindExpr);
        auto filterExpression =
            std::make_shared<binder::ScalarFunctionExpression>(ExpressionType::FUNCTION,
                std::move(function), std::move(bindData), bindExpr, uniqueExpressionName);
        auto printInfo = std::make_unique<OPPrintInfo>();

        // Each new filter takes the previous result as its child, so the filters
        // for successive paths nest around the original operator.
        auto filter = std::make_shared<LogicalFilter>(
            std::static_pointer_cast<binder::Expression>(filterExpression), resultOp,
            std::move(printInfo));
        hasReplace = true;
        filter->computeFlatSchema();
        resultOp = filter;
    }
    return resultOp;
}
// Cross-product hook: returns the operator wrapped in the path-semantic filters,
// or the operator itself when appendPathSemanticFilter adds none.
std::shared_ptr<planner::LogicalOperator> PathSemanticRewriter::visitCrossProductReplace(
    std::shared_ptr<planner::LogicalOperator> op) {
    auto filteredProduct = appendPathSemanticFilter(op);
    return filteredProduct;
}

} // namespace optimizer
} // namespace kuzu
Loading

0 comments on commit e163d9a

Please sign in to comment.