-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support path semantic in non-recursive-path
- Loading branch information
wyj
committed
Dec 18, 2024
1 parent
970f5c4
commit e163d9a
Showing
14 changed files
with
244 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#pragma once | ||
|
||
#include "binder/binder.h" | ||
#include "binder/expression/expression.h" | ||
#include "logical_operator_visitor.h" | ||
#include "planner/operator/logical_plan.h" | ||
namespace kuzu { | ||
namespace main { | ||
class ClientContext; | ||
} | ||
namespace catalog { | ||
class Catalog; | ||
} | ||
namespace binder {} | ||
namespace optimizer { | ||
|
||
class PathSemanticRewriter : public LogicalOperatorVisitor { | ||
public: | ||
explicit PathSemanticRewriter(main::ClientContext* context) | ||
: hasReplace(false), context{context} {} | ||
void rewrite(planner::LogicalPlan* plan); | ||
|
||
private: | ||
void visitOperator(const std::shared_ptr<planner::LogicalOperator>& op, | ||
const std::shared_ptr<planner::LogicalOperator>& parent, int index); | ||
std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace( | ||
std::shared_ptr<planner::LogicalOperator> op) override; | ||
std::shared_ptr<planner::LogicalOperator> visitCrossProductReplace( | ||
std::shared_ptr<planner::LogicalOperator> op); | ||
|
||
private: | ||
bool hasReplace; | ||
main::ClientContext* context; | ||
binder::expression_vector scanExpression; | ||
std::shared_ptr<planner::LogicalOperator> appendPathSemanticFilter( | ||
const std::shared_ptr<planner::LogicalOperator> op); | ||
}; | ||
|
||
} // namespace optimizer | ||
} // namespace kuzu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#include "optimizer/path_semantic_rewriter.h" | ||
|
||
#include "binder/expression/path_expression.h" | ||
#include "binder/expression/scalar_function_expression.h" | ||
#include "binder/expression_visitor.h" | ||
#include "catalog/catalog.h" | ||
#include "common/exception/internal.h" | ||
#include "function/built_in_function_utils.h" | ||
#include "function/path/vector_path_functions.h" | ||
#include "function/scalar_function.h" | ||
#include "main/client_context.h" | ||
#include "planner/operator/logical_filter.h" | ||
using namespace kuzu::common; | ||
using namespace kuzu::planner; | ||
|
||
namespace kuzu { | ||
namespace optimizer { | ||
|
||
void PathSemanticRewriter::rewrite(planner::LogicalPlan* plan) { | ||
auto root = plan->getLastOperator(); | ||
visitOperator(root, nullptr, 0); | ||
} | ||
|
||
void PathSemanticRewriter::visitOperator(const std::shared_ptr<planner::LogicalOperator>& op, | ||
const std::shared_ptr<planner::LogicalOperator>& parent, int index) { | ||
|
||
auto result = op; | ||
switch (op->getOperatorType()) { | ||
case planner::LogicalOperatorType::HASH_JOIN: | ||
result = visitHashJoinReplace(op); | ||
break; | ||
case planner::LogicalOperatorType::CROSS_PRODUCT: | ||
result = visitCrossProductReplace(op); | ||
break; | ||
default: | ||
break; | ||
} | ||
|
||
if (hasReplace && parent != nullptr) { | ||
parent->setChild(index, result); | ||
|
||
} else { | ||
for (auto i = 0u; i < op->getNumChildren(); ++i) { | ||
visitOperator(op->getChild(i), op, i); | ||
} | ||
} | ||
} | ||
|
||
std::string semanticSwitch(const common::PathSemantic& semantic) { | ||
switch (semantic) { | ||
case common::PathSemantic::TRAIL: | ||
return function::IsTrailFunction::name; | ||
case common::PathSemantic::ACYCLIC: | ||
return function::IsACyclicFunction::name; | ||
default: | ||
return std::string(); | ||
} | ||
} | ||
|
||
std::shared_ptr<LogicalOperator> PathSemanticRewriter::visitHashJoinReplace( | ||
std::shared_ptr<LogicalOperator> op) { | ||
return appendPathSemanticFilter(op); | ||
} | ||
|
||
std::shared_ptr<LogicalOperator> PathSemanticRewriter::appendPathSemanticFilter( | ||
const std::shared_ptr<LogicalOperator> op) { | ||
// get path expression from binder | ||
auto pathExpressions = context->getBinder()->findPathExpressionInScope(); | ||
auto catalog = context->getCatalog(); | ||
auto transaction = context->getTx(); | ||
auto semanticFunctionName = | ||
semanticSwitch(context->getClientConfig()->recursivePatternSemantic); | ||
if (semanticFunctionName.empty()) { | ||
return op; | ||
} | ||
|
||
auto resultOp = op; | ||
// append is_trail or is_acyclic function filter | ||
for (auto& expr : pathExpressions) { | ||
|
||
std::vector<LogicalType> childrenTypes; | ||
childrenTypes.push_back(expr->getDataType().copy()); | ||
|
||
auto bindExpr = binder::expression_vector{expr}; | ||
auto functions = catalog->getFunctions(transaction); | ||
auto function = function::BuiltInFunctionsUtils::matchFunction(transaction, | ||
semanticFunctionName, childrenTypes, functions) | ||
->ptrCast<function::ScalarFunction>() | ||
->copy(); | ||
|
||
std::unique_ptr<function::FunctionBindData> bindData; | ||
{ | ||
if (function.bindFunc) { | ||
bindData = function.bindFunc({bindExpr, &function, context}); | ||
} else { | ||
bindData = std::make_unique<function::FunctionBindData>( | ||
LogicalType(function.returnTypeID)); | ||
} | ||
} | ||
|
||
auto uniqueExpressionName = | ||
binder::ScalarFunctionExpression::getUniqueName(function.name, bindExpr); | ||
auto filterExpression = | ||
std::make_shared<binder::ScalarFunctionExpression>(ExpressionType::FUNCTION, | ||
std::move(function), std::move(bindData), bindExpr, uniqueExpressionName); | ||
auto printInfo = std::make_unique<OPPrintInfo>(); | ||
|
||
auto filter = std::make_shared<LogicalFilter>( | ||
std::static_pointer_cast<binder::Expression>(filterExpression), resultOp, | ||
std::move(printInfo)); | ||
hasReplace = true; | ||
filter->computeFlatSchema(); | ||
resultOp = filter; | ||
} | ||
return resultOp; | ||
} | ||
std::shared_ptr<planner::LogicalOperator> PathSemanticRewriter::visitCrossProductReplace( | ||
std::shared_ptr<planner::LogicalOperator> op) { | ||
return appendPathSemanticFilter(op); | ||
} | ||
|
||
} // namespace optimizer | ||
} // namespace kuzu |
Oops, something went wrong.