Skip to content

Commit

Permalink
path semantic use no equal
Browse files Browse the repository at this point in the history
  • Loading branch information
wyj committed Dec 18, 2024
1 parent e163d9a commit aaa2bc4
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 32 deletions.
24 changes: 15 additions & 9 deletions src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,27 @@ void PropertyCollector::visitQueryPartSkipNodeRel(const NormalizedQueryPart& que

void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) {
auto& matchClause = readingClause.constCast<BoundMatchClause>();
if (recursivePatternSemantic == PathSemantic::WALK) {
if (matchClause.hasPredicate()) {
collectProperties(matchClause.getPredicate());
}
} else {
if (matchClause.hasPredicate()) {
collectProperties(matchClause.getPredicate());
}
if (recursivePatternSemantic != PathSemantic::WALK) {
for (auto node : matchClause.getQueryGraphCollection()->getQueryNodes()) {
for (auto prop : node->getPropertyExprs()) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(node->getPropertyExpression(name));
if (prop->constCast<PropertyExpression>().getDataType().getLogicalTypeID() ==
LogicalTypeID::INTERNAL_ID) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(node->getPropertyExpression(name));
}
}
}

for (auto rel : matchClause.getQueryGraphCollection()->getQueryRels()) {
for (auto prop : rel->getPropertyExprs()) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(rel->getPropertyExpression(name));
if (prop->constCast<PropertyExpression>().getDataType().getLogicalTypeID() ==
LogicalTypeID::INTERNAL_ID) {
auto name = prop->constCast<PropertyExpression>().getPropertyName();
collectProperties(rel->getPropertyExpression(name));
}
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ class Binder {

expression_vector findPathExpressionInScope();

const BinderScope& getBinderScope() { return scope; }

private:
uint32_t lastInternalPathId;
common::idx_t lastExpressionId;
Expand Down
9 changes: 6 additions & 3 deletions src/include/optimizer/path_semantic_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ namespace optimizer {

class PathSemanticRewriter : public LogicalOperatorVisitor {
public:
explicit PathSemanticRewriter(main::ClientContext* context)
: hasReplace(false), context{context} {}
explicit PathSemanticRewriter(main::ClientContext* context) : context{context} {}
void rewrite(planner::LogicalPlan* plan);

private:
Expand All @@ -29,11 +28,15 @@ class PathSemanticRewriter : public LogicalOperatorVisitor {
std::shared_ptr<planner::LogicalOperator> op);

private:
bool hasReplace;
bool hasRecursive = false;
std::shared_ptr<planner::LogicalOperator> topOp = nullptr;
int replaceIndex = -1;
main::ClientContext* context;
binder::expression_vector scanExpression;
std::shared_ptr<planner::LogicalOperator> appendPathSemanticFilter(
const std::shared_ptr<planner::LogicalOperator> op);
std::shared_ptr<binder::Expression> createNode(const std::shared_ptr<binder::Expression>& expr);
std::shared_ptr<binder::Expression> createRel(const std::shared_ptr<binder::Expression>& expr);
};

} // namespace optimizer
Expand Down
218 changes: 198 additions & 20 deletions src/optimizer/path_semantic_rewriter.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "optimizer/path_semantic_rewriter.h"

#include "binder/expression/expression.h"
#include "binder/expression/path_expression.h"
#include "binder/expression/property_expression.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_visitor.h"
#include "catalog/catalog.h"
Expand All @@ -9,7 +11,9 @@
#include "function/path/vector_path_functions.h"
#include "function/scalar_function.h"
#include "main/client_context.h"
#include "planner/operator/extend/logical_extend.h"
#include "planner/operator/logical_filter.h"
#include "planner/operator/logical_hash_join.h"
using namespace kuzu::common;
using namespace kuzu::planner;

Expand All @@ -19,30 +23,36 @@ namespace optimizer {
void PathSemanticRewriter::rewrite(planner::LogicalPlan* plan) {
auto root = plan->getLastOperator();
visitOperator(root, nullptr, 0);
if (hasRecursive) {
topOp->setChild(replaceIndex, appendPathSemanticFilter(topOp->getChild(replaceIndex)));
}
}

void PathSemanticRewriter::visitOperator(const std::shared_ptr<planner::LogicalOperator>& op,
const std::shared_ptr<planner::LogicalOperator>& parent, int index) {

for (auto i = 0u; i < op->getNumChildren(); ++i) {
visitOperator(op->getChild(i), op, i);
}
auto result = op;
switch (op->getOperatorType()) {
case planner::LogicalOperatorType::HASH_JOIN:
result = visitHashJoinReplace(op);
if (hasRecursive) {
topOp = parent;
replaceIndex = index;
}
break;
case planner::LogicalOperatorType::CROSS_PRODUCT:
result = visitCrossProductReplace(op);
if (hasRecursive) {
topOp = parent;
replaceIndex = index;
}
break;
default:
break;
}

if (hasReplace && parent != nullptr) {
if (parent != nullptr) {
parent->setChild(index, result);

} else {
for (auto i = 0u; i < op->getNumChildren(); ++i) {
visitOperator(op->getChild(i), op, i);
}
}
}

Expand All @@ -57,9 +67,156 @@ std::string semanticSwitch(const common::PathSemantic& semantic) {
}
}

bool checkPattern(const common::PathSemantic& semantic, int nodeCount, int relCount) {
switch (semantic) {
case common::PathSemantic::TRAIL:
return relCount > 1;
case common::PathSemantic::ACYCLIC:
return nodeCount > 1;
default:
return false;
}
}

std::shared_ptr<LogicalOperator> PathSemanticRewriter::visitHashJoinReplace(
std::shared_ptr<LogicalOperator> op) {
return appendPathSemanticFilter(op);
auto hashJoin = (LogicalHashJoin*)op.get();
auto schema = hashJoin->getSchema();
auto exprs = schema->getExpressionsInScope();

for (auto expr : exprs) {
if (expr->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL) {
hasRecursive = true;
return op;
}
}
binder::expression_vector patterns;
int nodeCount = 0, relCount = 0;
std::unordered_set<std::string> nameSet;
for (auto expr : exprs) {
if (expr->expressionType == ExpressionType::PROPERTY) {
auto rawName = expr->constCast<binder::PropertyExpression>().getRawVariableName();
if (nameSet.contains(rawName)) {
continue;
}
nameSet.insert(rawName);
auto scopeExprVector = context->getBinder()->getBinderScope().getExpressions();
for (auto scopeExpr : scopeExprVector) {
if (scopeExpr->toString() == rawName) {
if (scopeExpr->dataType.getLogicalTypeID() == LogicalTypeID::NODE &&
context->getClientConfig()->recursivePatternSemantic ==
common::PathSemantic::ACYCLIC) {
std::shared_ptr<binder::Expression> queryNode = createNode(scopeExpr);
patterns.push_back(queryNode);
nodeCount++;
} else if (scopeExpr->dataType.getLogicalTypeID() == LogicalTypeID::REL &&
context->getClientConfig()->recursivePatternSemantic ==
common::PathSemantic::TRAIL) {
std::shared_ptr<binder::Expression> queryRel = createRel(scopeExpr);
patterns.push_back(queryRel);
relCount++;
}
}
}
}
}
if (!patterns.empty()) {
auto semanticFunctionName =
semanticSwitch(context->getClientConfig()->recursivePatternSemantic);
if (semanticFunctionName.empty() ||
!checkPattern(context->getClientConfig()->recursivePatternSemantic, nodeCount,
relCount)) {
return op;
}
auto pathName = context->getBinder()->getInternalPathName();
auto pathExpression = context->getBinder()->createPath(pathName, patterns);
auto catalog = context->getCatalog();
auto transaction = context->getTx();
auto functions = catalog->getFunctions(transaction);

auto resultOp = op;

for (auto i = 0u; i < patterns.size() - 1; ++i) {
std::shared_ptr<binder::Expression> left = nullptr, right = nullptr;
if (context->getClientConfig()->recursivePatternSemantic ==
common::PathSemantic::TRAIL) {
left = patterns[i]->constCast<binder::RelExpression>().getInternalIDProperty();
right = patterns[i + 1]->constCast<binder::RelExpression>().getInternalIDProperty();
} else {
left = patterns[i]->constCast<binder::NodeExpression>().getInternalID();
right = patterns[i + 1]->constCast<binder::NodeExpression>().getInternalID();
}
binder::expression_vector children;
children.push_back(left);
children.push_back(right);
auto noEquals = context->getBinder()->getExpressionBinder()->bindComparisonExpression(
kuzu::common::ExpressionType::NOT_EQUALS, children);
auto filter = std::make_shared<LogicalFilter>(
std::static_pointer_cast<binder::Expression>(noEquals), resultOp);
filter->computeFlatSchema();
resultOp = filter;
}
return resultOp;
}
return op;
// return appendPathSemanticFilter(op);
}
std::shared_ptr<binder::Expression> PathSemanticRewriter::createRel(
const std::shared_ptr<binder::Expression>& expr) {
auto& relExpr = expr->constCast<binder::RelExpression>();
std::vector<catalog::TableCatalogEntry*> relTableEntries(relExpr.getEntries());
auto queryRel = make_shared<binder::RelExpression>(LogicalType(LogicalTypeID::REL),
relExpr.getUniqueName(), relExpr.getVariableName(), relTableEntries, relExpr.getSrcNode(),
relExpr.getDstNode(), relExpr.getDirectionType(), QueryRelType::NON_RECURSIVE);
queryRel->setAlias(relExpr.getVariableName());
queryRel->setLabelExpression(relExpr.getLabelExpression());

queryRel->addPropertyExpression(InternalKeyword::ID, relExpr.getInternalIDProperty()->copy());
std::vector<StructField> fields;
fields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID());
fields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID());
// Bind internal expressions.
fields.emplace_back(InternalKeyword::LABEL,
queryRel->getLabelExpression()->getDataType().copy());
// Bind properties.
for (auto& expression : queryRel->getPropertyExprsRef()) {
auto& prop = expression->constCast<binder::PropertyExpression>();
if (prop.isInternalID()) {
fields.emplace_back(prop.getPropertyName(), prop.getDataType().copy());
}
}
auto extraInfo = std::make_unique<StructTypeInfo>(std::move(fields));
queryRel->setExtraTypeInfo(std::move(extraInfo));
return queryRel;
}
std::shared_ptr<binder::Expression> PathSemanticRewriter::createNode(
const std::shared_ptr<binder::Expression>& expr) {
auto& nodeExpr = expr->constCast<binder::NodeExpression>();
std::vector<catalog::TableCatalogEntry*> nodeEntries(nodeExpr.getEntries());
auto queryNode = make_shared<binder::NodeExpression>(LogicalType(LogicalTypeID::NODE),
nodeExpr.getUniqueName(), nodeExpr.getVariableName(), nodeEntries);

queryNode->setInternalID(nodeExpr.getInternalID()->copy());
queryNode->setLabelExpression(nodeExpr.getLabelExpression());
queryNode->addPropertyExpression(nodeExpr.getInternalID()->toString(),
nodeExpr.getInternalID()->copy());
queryNode->setAlias(nodeExpr.getVariableName());
std::vector<std::string> fieldNames;
std::vector<LogicalType> fieldTypes;
fieldNames.emplace_back(InternalKeyword::ID);
fieldNames.emplace_back(InternalKeyword::LABEL);
fieldTypes.push_back(queryNode->getInternalID()->getDataType().copy());
fieldTypes.push_back(queryNode->getLabelExpression()->getDataType().copy());
for (auto& expression : queryNode->getPropertyExprsRef()) {
auto prop = ku_dynamic_cast<binder::PropertyExpression*>(expression.get());
if (prop->isInternalID()) {
fieldNames.emplace_back(prop->getPropertyName());
fieldTypes.emplace_back(prop->dataType.copy());
}
}
auto extraInfo = std::make_unique<StructTypeInfo>(fieldNames, fieldTypes);
queryNode->setExtraTypeInfo(std::move(extraInfo));
return queryNode;
}

std::shared_ptr<LogicalOperator> PathSemanticRewriter::appendPathSemanticFilter(
Expand All @@ -77,11 +234,35 @@ std::shared_ptr<LogicalOperator> PathSemanticRewriter::appendPathSemanticFilter(
auto resultOp = op;
// append is_trail or is_acyclic function filter
for (auto& expr : pathExpressions) {
auto& pathExpr = expr->constCast<binder::PathExpression>();

binder::expression_vector patterns;
int nodeCount = 0, relCount = 0, recursiveCount = 0;
for (auto child : pathExpr.getChildren()) {
if (child->dataType.getLogicalTypeID() == LogicalTypeID::NODE) {
std::shared_ptr<binder::Expression> queryNode = createNode(child);
patterns.push_back(queryNode);
nodeCount++;
} else if (child->dataType.getLogicalTypeID() == LogicalTypeID::REL) {
std::shared_ptr<binder::Expression> queryRel = createRel(child);
patterns.push_back(queryRel);
relCount++;
} else if (child->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL) {
patterns.push_back(child);
recursiveCount++;
}
}
if (relCount == 0) {
// only recursiveRel
continue;
}

auto pathName = context->getBinder()->getInternalPathName();
auto newPathExpr = context->getBinder()->createPath(pathName, patterns);
std::vector<LogicalType> childrenTypes;
childrenTypes.push_back(expr->getDataType().copy());
childrenTypes.push_back(newPathExpr->getDataType().copy());

auto bindExpr = binder::expression_vector{expr};
auto bindExpr = binder::expression_vector{newPathExpr};
auto functions = catalog->getFunctions(transaction);
auto function = function::BuiltInFunctionsUtils::matchFunction(transaction,
semanticFunctionName, childrenTypes, functions)
Expand All @@ -90,25 +271,22 @@ std::shared_ptr<LogicalOperator> PathSemanticRewriter::appendPathSemanticFilter(

std::unique_ptr<function::FunctionBindData> bindData;
{
if (function.bindFunc) {
bindData = function.bindFunc({bindExpr, &function, context});
if (function->bindFunc) {
bindData = function->bindFunc({bindExpr, function.get(), context});
} else {
bindData = std::make_unique<function::FunctionBindData>(
LogicalType(function.returnTypeID));
LogicalType(function->returnTypeID));
}
}

auto uniqueExpressionName =
binder::ScalarFunctionExpression::getUniqueName(function.name, bindExpr);
binder::ScalarFunctionExpression::getUniqueName(function->name, bindExpr);
auto filterExpression =
std::make_shared<binder::ScalarFunctionExpression>(ExpressionType::FUNCTION,
std::move(function), std::move(bindData), bindExpr, uniqueExpressionName);
auto printInfo = std::make_unique<OPPrintInfo>();

auto filter = std::make_shared<LogicalFilter>(
std::static_pointer_cast<binder::Expression>(filterExpression), resultOp,
std::move(printInfo));
hasReplace = true;
std::static_pointer_cast<binder::Expression>(filterExpression), resultOp);
filter->computeFlatSchema();
resultOp = filter;
}
Expand Down

0 comments on commit aaa2bc4

Please sign in to comment.