Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support path semantic in non-recursive-path #4405

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ QueryGraph Binder::bindPatternElement(const PatternElement& patternElement) {
nodeAndRels.push_back(rightNode);
leftNode = rightNode;
}
if (patternElement.hasPathName()) {
auto pathName = patternElement.getPathName();

if (patternElement.hasPathName() ||
this->clientContext->getClientConfig()->recursivePatternSemantic != PathSemantic::WALK) {
// query may not explicitly define a path name, but it requires a path for semantic filter;
// therefore, an implicitly defined path is added internally.
auto pathName =
patternElement.hasPathName() ? patternElement.getPathName() : getInternalPathName();
auto pathExpression = createPath(pathName, nodeAndRels);
addToScope(pathName, pathExpression);
}
Expand Down
14 changes: 14 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ std::string Binder::getUniqueExpressionName(const std::string& name) {
return "_" + std::to_string(lastExpressionId++) + "_" + name;
}

std::string Binder::getInternalPathName() {
return InternalKeyword::INTERNAL_PATH + std::to_string(lastInternalPathId++);
}

struct ReservedNames {
// Column name that might conflict with internal names.
static std::unordered_set<std::string> getColumnNames() {
Expand Down Expand Up @@ -267,5 +271,15 @@ function::TableFunction Binder::getScanFunction(FileTypeInfo typeInfo, const Rea
return *func->ptrCast<function::TableFunction>();
}

expression_vector Binder::findPathExpressionInScope() {
expression_vector result;
for (auto expr : this->scope.getExpressions()) {
if (expr->expressionType == ExpressionType::PATH) {
result.push_back(expr);
}
}
return result;
}

} // namespace binder
} // namespace kuzu
22 changes: 22 additions & 0 deletions src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/visitor/property_collector.h"

#include "binder/expression/expression_util.h"
#include "binder/expression/property_expression.h"
#include "binder/expression_visitor.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
Expand Down Expand Up @@ -54,6 +55,27 @@ void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) {
if (matchClause.hasPredicate()) {
collectProperties(matchClause.getPredicate());
}
if (recursivePatternSemantic != PathSemantic::WALK) {
for (auto node : matchClause.getQueryGraphCollection()->getQueryNodes()) {
for (auto prop : node->getPropertyExprs()) {
if (prop->constCast<PropertyExpression>().getDataType().getLogicalTypeID() ==
LogicalTypeID::INTERNAL_ID) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(node->getPropertyExpression(name));
}
}
}

for (auto rel : matchClause.getQueryGraphCollection()->getQueryRels()) {
for (auto prop : rel->getPropertyExprs()) {
if (prop->constCast<PropertyExpression>().getDataType().getLogicalTypeID() ==
LogicalTypeID::INTERNAL_ID) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(rel->getPropertyExpression(name));
}
}
}
}
}

void PropertyCollector::visitUnwind(const BoundReadingClause& readingClause) {
Expand Down
10 changes: 8 additions & 2 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class Binder {

public:
explicit Binder(main::ClientContext* clientContext)
: lastExpressionId{0}, scope{}, expressionBinder{this, clientContext},
clientContext{clientContext} {}
: lastInternalPathId{0}, lastExpressionId{0}, scope{},
expressionBinder{this, clientContext}, clientContext{clientContext} {}

std::unique_ptr<BoundStatement> bind(const parser::Statement& statement);

Expand Down Expand Up @@ -291,6 +291,7 @@ class Binder {
void validateDropSequence(const parser::Statement& dropTable);
/*** helpers ***/
std::string getUniqueExpressionName(const std::string& name);
std::string getInternalPathName();

static bool reservedInColumnName(const std::string& name);
static bool reservedInPropertyLookup(const std::string& name);
Expand All @@ -305,7 +306,12 @@ class Binder {

ExpressionBinder* getExpressionBinder() { return &expressionBinder; }

expression_vector findPathExpressionInScope();

const BinderScope& getBinderScope() { return scope; }

private:
uint32_t lastInternalPathId;
common::idx_t lastExpressionId;
BinderScope scope;
ExpressionBinder expressionBinder;
Expand Down
9 changes: 9 additions & 0 deletions src/include/binder/visitor/property_collector.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include "binder/bound_statement_visitor.h"
#include "common/enums/path_semantic.h"
#include "main/client_config.h"

namespace kuzu {
namespace binder {
Expand All @@ -14,6 +16,11 @@ class PropertyCollector final : public BoundStatementVisitor {
// See with_clause_projection_rewriter for more details.
void visitSingleQuerySkipNodeRel(const NormalizedSingleQuery& singleQuery);

inline void setRecursivePatternSemantic(common::PathSemantic semantic) {
recursivePatternSemantic = semantic;
}
inline common::PathSemantic getRecursivePatternSemantic() { return recursivePatternSemantic; }

private:
void visitQueryPartSkipNodeRel(const NormalizedQueryPart& queryPart);

Expand All @@ -36,6 +43,8 @@ class PropertyCollector final : public BoundStatementVisitor {

private:
expression_set properties;
common::PathSemantic recursivePatternSemantic =
main::ClientConfigDefault::RECURSIVE_PATTERN_SEMANTIC;
};

} // namespace binder
Expand Down
1 change: 1 addition & 0 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct InternalKeyword {
static constexpr char PLACE_HOLDER[] = "_PLACE_HOLDER";
static constexpr char MAP_KEY[] = "KEY";
static constexpr char MAP_VALUE[] = "VALUE";
static constexpr char INTERNAL_PATH[] = "_INTERNAL_PATH";

static constexpr std::string_view ROW_OFFSET = "_row_offset";
static constexpr std::string_view SRC_OFFSET = "_src_offset";
Expand Down
4 changes: 4 additions & 0 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class KUZU_API ClientContext {
common::VirtualFileSystem* getVFSUnsafe() const;
common::RandomEngine* getRandomEngine();

// binder
binder::Binder* getBinder() const;

// Query.
std::unique_ptr<PreparedStatement> prepare(std::string_view query);
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand Down Expand Up @@ -208,6 +211,7 @@ class KUZU_API ClientContext {
// Graph entries
std::unique_ptr<graph::GraphEntrySet> graphEntrySet;
std::mutex mtx;
std::unique_ptr<binder::Binder> binder;
};

} // namespace main
Expand Down
43 changes: 43 additions & 0 deletions src/include/optimizer/path_semantic_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include "binder/binder.h"
#include "binder/expression/expression.h"
#include "logical_operator_visitor.h"
#include "planner/operator/logical_plan.h"
namespace kuzu {
namespace main {
class ClientContext;
}
namespace catalog {
class Catalog;
}
namespace binder {}
namespace optimizer {

class PathSemanticRewriter : public LogicalOperatorVisitor {
public:
explicit PathSemanticRewriter(main::ClientContext* context) : context{context} {}
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(const std::shared_ptr<planner::LogicalOperator>& op,
const std::shared_ptr<planner::LogicalOperator>& parent, int index);
std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) override;
std::shared_ptr<planner::LogicalOperator> visitCrossProductReplace(
std::shared_ptr<planner::LogicalOperator> op);

private:
bool hasRecursive = false;
std::shared_ptr<planner::LogicalOperator> topOp = nullptr;
int replaceIndex = -1;
main::ClientContext* context;
binder::expression_vector scanExpression;
std::shared_ptr<planner::LogicalOperator> appendPathSemanticFilter(
const std::shared_ptr<planner::LogicalOperator> op);
std::shared_ptr<binder::Expression> createNode(const std::shared_ptr<binder::Expression>& expr);
std::shared_ptr<binder::Expression> createRel(const std::shared_ptr<binder::Expression>& expr);
};

} // namespace optimizer
} // namespace kuzu
11 changes: 7 additions & 4 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
}
}
// binding
auto binder = Binder(this);
binder = std::make_unique<binder::Binder>((this));
if (inputParams) {
binder.setInputParameters(*inputParams);
binder->setInputParameters(*inputParams);
}
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement->parameterMap = binder.getParameterMap();
auto boundStatement = binder->bind(*parsedStatement);
preparedStatement->parameterMap = binder->getParameterMap();
preparedStatement->statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
// planning
Expand Down Expand Up @@ -581,6 +581,9 @@ processor::WarningContext& ClientContext::getWarningContextUnsafe() {
const processor::WarningContext& ClientContext::getWarningContext() const {
return warningContext;
}
binder::Binder* ClientContext::getBinder() const {
return binder.get();
}

graph::GraphEntrySet& ClientContext::getGraphEntrySetUnsafe() {
return *graphEntrySet;
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_library(kuzu_optimizer
remove_factorization_rewriter.cpp
remove_unnecessary_join_optimizer.cpp
top_k_optimizer.cpp
path_semantic_rewriter.cpp
limit_push_down_optimizer.cpp)

set(ALL_OBJECT_FILES
Expand Down
6 changes: 6 additions & 0 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "optimizer/factorization_rewriter.h"
#include "optimizer/filter_push_down_optimizer.h"
#include "optimizer/limit_push_down_optimizer.h"
#include "optimizer/path_semantic_rewriter.h"
#include "optimizer/projection_push_down_optimizer.h"
#include "optimizer/remove_factorization_rewriter.h"
#include "optimizer/remove_unnecessary_join_optimizer.h"
Expand All @@ -29,6 +30,11 @@ void Optimizer::optimize(planner::LogicalPlan* plan, main::ClientContext* contex
auto correlatedSubqueryUnnestSolver = CorrelatedSubqueryUnnestSolver(nullptr);
correlatedSubqueryUnnestSolver.solve(plan->getLastOperator().get());

auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer();
removeUnnecessaryJoinOptimizer.rewrite(plan);
auto pathSemanticRewriter = PathSemanticRewriter(context);
pathSemanticRewriter.rewrite(plan);

auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer();
removeUnnecessaryJoinOptimizer.rewrite(plan);

Expand Down
Loading