diff --git a/src/binder/visitor/property_collector.cpp b/src/binder/visitor/property_collector.cpp index b8630677bd4..22bb9a345b9 100644 --- a/src/binder/visitor/property_collector.cpp +++ b/src/binder/visitor/property_collector.cpp @@ -52,21 +52,27 @@ void PropertyCollector::visitQueryPartSkipNodeRel(const NormalizedQueryPart& que void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) { auto& matchClause = readingClause.constCast(); - if (recursivePatternSemantic == PathSemantic::WALK) { - if (matchClause.hasPredicate()) { - collectProperties(matchClause.getPredicate()); - } - } else { + if (matchClause.hasPredicate()) { + collectProperties(matchClause.getPredicate()); + } + if (recursivePatternSemantic != PathSemantic::WALK) { for (auto node : matchClause.getQueryGraphCollection()->getQueryNodes()) { for (auto prop : node->getPropertyExprs()) { - auto name = prop->constCast().getPropertyName(); - collectProperties(node->getPropertyExpression(name)); + if (prop->constCast().getDataType().getLogicalTypeID() == + LogicalTypeID::INTERNAL_ID) { + auto name = prop->constCast().getPropertyName(); + collectProperties(node->getPropertyExpression(name)); + } } } + for (auto rel : matchClause.getQueryGraphCollection()->getQueryRels()) { for (auto prop : rel->getPropertyExprs()) { - auto name = prop->constCast().getPropertyName(); - collectProperties(rel->getPropertyExpression(name)); + if (prop->constCast().getDataType().getLogicalTypeID() == + LogicalTypeID::INTERNAL_ID) { + auto name = prop->constCast().getPropertyName(); + collectProperties(rel->getPropertyExpression(name)); + } } } } diff --git a/src/include/binder/binder.h b/src/include/binder/binder.h index d9c7558cd17..39b596344ae 100644 --- a/src/include/binder/binder.h +++ b/src/include/binder/binder.h @@ -308,6 +308,8 @@ class Binder { expression_vector findPathExpressionInScope(); + const BinderScope& getBinderScope() { return scope; } + private: uint32_t lastInternalPathId; common::idx_t lastExpressionId; diff --git a/src/include/optimizer/path_semantic_rewriter.h b/src/include/optimizer/path_semantic_rewriter.h index d5dbcdebbe4..d05a02d1e1d 100644 --- a/src/include/optimizer/path_semantic_rewriter.h +++ b/src/include/optimizer/path_semantic_rewriter.h @@ -16,8 +16,7 @@ namespace optimizer { class PathSemanticRewriter : public LogicalOperatorVisitor { public: - explicit PathSemanticRewriter(main::ClientContext* context) - : hasReplace(false), context{context} {} + explicit PathSemanticRewriter(main::ClientContext* context) : context{context} {} void rewrite(planner::LogicalPlan* plan); private: @@ -29,11 +28,15 @@ class PathSemanticRewriter : public LogicalOperatorVisitor { std::shared_ptr op); private: - bool hasReplace; + bool hasRecursive = false; + std::shared_ptr topOp = nullptr; + int replaceIndex = -1; main::ClientContext* context; binder::expression_vector scanExpression; std::shared_ptr appendPathSemanticFilter( const std::shared_ptr op); + std::shared_ptr createNode(const std::shared_ptr& expr); + std::shared_ptr createRel(const std::shared_ptr& expr); }; } // namespace optimizer diff --git a/src/optimizer/path_semantic_rewriter.cpp b/src/optimizer/path_semantic_rewriter.cpp index b011bd3b0e1..5cc3d772c86 100644 --- a/src/optimizer/path_semantic_rewriter.cpp +++ b/src/optimizer/path_semantic_rewriter.cpp @@ -1,6 +1,8 @@ #include "optimizer/path_semantic_rewriter.h" +#include "binder/expression/expression.h" #include "binder/expression/path_expression.h" +#include "binder/expression/property_expression.h" #include "binder/expression/scalar_function_expression.h" #include "binder/expression_visitor.h" #include "catalog/catalog.h" @@ -9,7 +11,9 @@ #include "function/path/vector_path_functions.h" #include "function/scalar_function.h" #include "main/client_context.h" +#include "planner/operator/extend/logical_extend.h" #include "planner/operator/logical_filter.h" +#include "planner/operator/logical_hash_join.h" using namespace kuzu::common; using namespace kuzu::planner; @@ -19,30 +23,36 @@ namespace optimizer { void PathSemanticRewriter::rewrite(planner::LogicalPlan* plan) { auto root = plan->getLastOperator(); visitOperator(root, nullptr, 0); + if (hasRecursive) { + topOp->setChild(replaceIndex, appendPathSemanticFilter(topOp->getChild(replaceIndex))); + } } void PathSemanticRewriter::visitOperator(const std::shared_ptr& op, const std::shared_ptr& parent, int index) { - + for (auto i = 0u; i < op->getNumChildren(); ++i) { + visitOperator(op->getChild(i), op, i); + } auto result = op; switch (op->getOperatorType()) { case planner::LogicalOperatorType::HASH_JOIN: result = visitHashJoinReplace(op); + if (hasRecursive) { + topOp = parent; + replaceIndex = index; + } break; case planner::LogicalOperatorType::CROSS_PRODUCT: - result = visitCrossProductReplace(op); + if (hasRecursive) { + topOp = parent; + replaceIndex = index; + } break; default: break; } - - if (hasReplace && parent != nullptr) { + if (parent != nullptr) { parent->setChild(index, result); - - } else { - for (auto i = 0u; i < op->getNumChildren(); ++i) { - visitOperator(op->getChild(i), op, i); - } } } @@ -57,9 +67,156 @@ std::string semanticSwitch(const common::PathSemantic& semantic) { } } +bool checkPattern(const common::PathSemantic& semantic, int nodeCount, int relCount) { + switch (semantic) { + case common::PathSemantic::TRAIL: + return relCount > 1; + case common::PathSemantic::ACYCLIC: + return nodeCount > 1; + default: + return false; + } +} + std::shared_ptr PathSemanticRewriter::visitHashJoinReplace( std::shared_ptr op) { - return appendPathSemanticFilter(op); + auto hashJoin = (LogicalHashJoin*)op.get(); + auto schema = hashJoin->getSchema(); + auto exprs = schema->getExpressionsInScope(); + + for (auto expr : exprs) { + if (expr->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL) { + hasRecursive = true; + return op; + } + } + binder::expression_vector patterns; + int nodeCount = 0, relCount = 0; + std::unordered_set nameSet; + for (auto expr : exprs) { + if (expr->expressionType == ExpressionType::PROPERTY) { + auto rawName = expr->constCast().getRawVariableName(); + if (nameSet.contains(rawName)) { + continue; + } + nameSet.insert(rawName); + auto scopeExprVector = context->getBinder()->getBinderScope().getExpressions(); + for (auto scopeExpr : scopeExprVector) { + if (scopeExpr->toString() == rawName) { + if (scopeExpr->dataType.getLogicalTypeID() == LogicalTypeID::NODE && + context->getClientConfig()->recursivePatternSemantic == + common::PathSemantic::ACYCLIC) { + std::shared_ptr queryNode = createNode(scopeExpr); + patterns.push_back(queryNode); + nodeCount++; + } else if (scopeExpr->dataType.getLogicalTypeID() == LogicalTypeID::REL && + context->getClientConfig()->recursivePatternSemantic == + common::PathSemantic::TRAIL) { + std::shared_ptr queryRel = createRel(scopeExpr); + patterns.push_back(queryRel); + relCount++; + } + } + } + } + } + if (!patterns.empty()) { + auto semanticFunctionName = + semanticSwitch(context->getClientConfig()->recursivePatternSemantic); + if (semanticFunctionName.empty() || + !checkPattern(context->getClientConfig()->recursivePatternSemantic, nodeCount, + relCount)) { + return op; + } + auto pathName = context->getBinder()->getInternalPathName(); + auto pathExpression = context->getBinder()->createPath(pathName, patterns); + auto catalog = context->getCatalog(); + auto transaction = context->getTx(); + auto functions = catalog->getFunctions(transaction); + + auto resultOp = op; + + for (auto i = 0u; i < patterns.size() - 1; ++i) { + std::shared_ptr left = nullptr, right = nullptr; + if (context->getClientConfig()->recursivePatternSemantic == + common::PathSemantic::TRAIL) { + left = patterns[i]->constCast().getInternalIDProperty(); + right = patterns[i + 1]->constCast().getInternalIDProperty(); + } else { + left = patterns[i]->constCast().getInternalID(); + right = patterns[i + 1]->constCast().getInternalID(); + } + binder::expression_vector children; + children.push_back(left); + children.push_back(right); + auto noEquals = context->getBinder()->getExpressionBinder()->bindComparisonExpression( + kuzu::common::ExpressionType::NOT_EQUALS, children); + auto filter = std::make_shared( + std::static_pointer_cast(noEquals), resultOp); + filter->computeFlatSchema(); + resultOp = filter; + } + return resultOp; + } + return op; + // return appendPathSemanticFilter(op); +} +std::shared_ptr PathSemanticRewriter::createRel( + const std::shared_ptr& expr) { + auto& relExpr = expr->constCast(); + std::vector relTableEntries(relExpr.getEntries()); + auto queryRel = make_shared(LogicalType(LogicalTypeID::REL), + relExpr.getUniqueName(), relExpr.getVariableName(), relTableEntries, relExpr.getSrcNode(), + relExpr.getDstNode(), relExpr.getDirectionType(), QueryRelType::NON_RECURSIVE); + queryRel->setAlias(relExpr.getVariableName()); + queryRel->setLabelExpression(relExpr.getLabelExpression()); + + queryRel->addPropertyExpression(InternalKeyword::ID, relExpr.getInternalIDProperty()->copy()); + std::vector fields; + fields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID()); + fields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID()); + // Bind internal expressions. + fields.emplace_back(InternalKeyword::LABEL, + queryRel->getLabelExpression()->getDataType().copy()); + // Bind properties. + for (auto& expression : queryRel->getPropertyExprsRef()) { + auto& prop = expression->constCast(); + if (prop.isInternalID()) { + fields.emplace_back(prop.getPropertyName(), prop.getDataType().copy()); + } + } + auto extraInfo = std::make_unique(std::move(fields)); + queryRel->setExtraTypeInfo(std::move(extraInfo)); + return queryRel; +} +std::shared_ptr PathSemanticRewriter::createNode( + const std::shared_ptr& expr) { + auto& nodeExpr = expr->constCast(); + std::vector nodeEntries(nodeExpr.getEntries()); + auto queryNode = make_shared(LogicalType(LogicalTypeID::NODE), + nodeExpr.getUniqueName(), nodeExpr.getVariableName(), nodeEntries); + + queryNode->setInternalID(nodeExpr.getInternalID()->copy()); + queryNode->setLabelExpression(nodeExpr.getLabelExpression()); + queryNode->addPropertyExpression(nodeExpr.getInternalID()->toString(), + nodeExpr.getInternalID()->copy()); + queryNode->setAlias(nodeExpr.getVariableName()); + std::vector fieldNames; + std::vector fieldTypes; + fieldNames.emplace_back(InternalKeyword::ID); + fieldNames.emplace_back(InternalKeyword::LABEL); + fieldTypes.push_back(queryNode->getInternalID()->getDataType().copy()); + fieldTypes.push_back(queryNode->getLabelExpression()->getDataType().copy()); + for (auto& expression : queryNode->getPropertyExprsRef()) { + auto prop = ku_dynamic_cast(expression.get()); + if (prop->isInternalID()) { + fieldNames.emplace_back(prop->getPropertyName()); + fieldTypes.emplace_back(prop->dataType.copy()); + } + } + auto extraInfo = std::make_unique(fieldNames, fieldTypes); + queryNode->setExtraTypeInfo(std::move(extraInfo)); + return queryNode; } std::shared_ptr PathSemanticRewriter::appendPathSemanticFilter( @@ -77,11 +234,35 @@ std::shared_ptr PathSemanticRewriter::appendPathSemanticFilter( auto resultOp = op; // append is_trail or is_acyclic function filter for (auto& expr : pathExpressions) { + auto& pathExpr = expr->constCast(); + + binder::expression_vector patterns; + int nodeCount = 0, relCount = 0, recursiveCount = 0; + for (auto child : pathExpr.getChildren()) { + if (child->dataType.getLogicalTypeID() == LogicalTypeID::NODE) { + std::shared_ptr queryNode = createNode(child); + patterns.push_back(queryNode); + nodeCount++; + } else if (child->dataType.getLogicalTypeID() == LogicalTypeID::REL) { + std::shared_ptr queryRel = createRel(child); + patterns.push_back(queryRel); + relCount++; + } else if (child->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL) { + patterns.push_back(child); + recursiveCount++; + } + } + if (relCount == 0) { + // only recursiveRel + continue; + } + auto pathName = context->getBinder()->getInternalPathName(); + auto newPathExpr = context->getBinder()->createPath(pathName, patterns); std::vector childrenTypes; - childrenTypes.push_back(expr->getDataType().copy()); + childrenTypes.push_back(newPathExpr->getDataType().copy()); - auto bindExpr = binder::expression_vector{expr}; + auto bindExpr = binder::expression_vector{newPathExpr}; auto functions = catalog->getFunctions(transaction); auto function = function::BuiltInFunctionsUtils::matchFunction(transaction, semanticFunctionName, childrenTypes, functions) @@ -90,25 +271,22 @@ std::shared_ptr PathSemanticRewriter::appendPathSemanticFilter( std::unique_ptr bindData; { - if (function.bindFunc) { - bindData = function.bindFunc({bindExpr, &function, context}); + if (function->bindFunc) { + bindData = function->bindFunc({bindExpr, function.get(), context}); } else { bindData = std::make_unique( - LogicalType(function.returnTypeID)); + LogicalType(function->returnTypeID)); } } auto uniqueExpressionName = - binder::ScalarFunctionExpression::getUniqueName(function.name, bindExpr); + binder::ScalarFunctionExpression::getUniqueName(function->name, bindExpr); auto filterExpression = std::make_shared(ExpressionType::FUNCTION, std::move(function), std::move(bindData), bindExpr, uniqueExpressionName); - auto printInfo = std::make_unique(); auto filter = std::make_shared( - std::static_pointer_cast(filterExpression), resultOp, - std::move(printInfo)); - hasReplace = true; + std::static_pointer_cast(filterExpression), resultOp); filter->computeFlatSchema(); resultOp = filter; }