Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Dec 18, 2024
1 parent 1ee84ea commit 01f78b3
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/include/planner/planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ class Planner {
void appendHashJoin(const std::vector<binder::expression_pair>& joinConditions,
common::JoinType joinType, std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendAccHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
void appendAccHashJoin(const std::vector<binder::expression_pair>& joinConditions, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan);
void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
Expand Down
8 changes: 4 additions & 4 deletions src/planner/plan/append_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ void Planner::appendHashJoin(const std::vector<expression_pair>& joinConditions,
resultPlan.setLastOperator(std::move(hashJoin));
}

void Planner::appendAccHashJoin(const expression_vector& joinNodeIDs, JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
void Planner::appendAccHashJoin(const std::vector<binder::expression_pair>& joinConditions,
JoinType joinType, std::shared_ptr<Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan) {
KU_ASSERT(probePlan.hasUpdate());
tryAppendAccumulate(probePlan);
appendHashJoin(joinNodeIDs, joinType, mark, probePlan, buildPlan, resultPlan);
appendHashJoin(joinConditions, joinType, mark, probePlan, buildPlan, resultPlan);
auto& sipInfo = probePlan.getLastOperator()->cast<LogicalHashJoin>().getSIPInfoUnsafe();
sipInfo.direction = SIPDirection::PROBE_TO_BUILD;
}
Expand Down
132 changes: 74 additions & 58 deletions src/planner/plan/plan_subquery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@ class UnnestSubqueryAnalyzer {
void analyze() {
for (auto predicate : predicates) {
if (predicate->expressionType != common::ExpressionType::EQUALS) {
unnestAsInnerJoin_ = false;
unnestAsJoin_ = false;
return;
}
if (isJoinCondition(*predicate->getChild(0), *predicate->getChild(1))) {
joinConditions.emplace_back(predicate->getChild(0), predicate->getChild(1));
} else if (isJoinCondition(*predicate->getChild(1), *predicate->getChild(0))) {
joinConditions.emplace_back(predicate->getChild(1), predicate->getChild(0));
} else {
unnestAsInnerJoin_ = false;
unnestAsJoin_ = false;
return;
}
}
if (unnestAsInnerJoin_) {
if (unnestAsJoin_) {
for (auto& node : queryGraphCollection.getQueryNodes()) {
if (schema.isExpressionInScope(*node->getInternalID())) {
joinConditions.emplace_back(node->getInternalID(), node->getInternalID());
Expand All @@ -64,7 +64,7 @@ class UnnestSubqueryAnalyzer {
}
}

bool unnestAsInnerJoin() const { return unnestAsInnerJoin_; }
bool unnestAsJoin() const { return unnestAsJoin_; }
std::vector<binder::expression_pair> getJoinConditions() const {
return joinConditions;
}
Expand All @@ -91,7 +91,7 @@ class UnnestSubqueryAnalyzer {
const QueryGraphCollection& queryGraphCollection;
expression_vector predicates;

bool unnestAsInnerJoin_ = true;
bool unnestAsJoin_ = true;
std::vector<binder::expression_pair> joinConditions;
};

Expand Down Expand Up @@ -119,19 +119,21 @@ void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection
//}

void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection,
const expression_vector& predicates, const binder::expression_vector& corrExprs,
const expression_vector& predicates, const binder::expression_vector& correlatedExprs,
std::shared_ptr<Expression> mark, LogicalPlan& leftPlan) {
auto info = QueryGraphPlanningInfo();
info.predicates = predicates;

if (leftPlan.isEmpty()) {
// Optional match is the first clause. No left plan to join.
// Optional match is the first clause, e.g. OPTIONAL MATCH <pattern> RETURN *
info.predicates = predicates;
auto plan = planQueryGraphCollection(queryGraphCollection, info);
leftPlan.setLastOperator(plan->getLastOperator());
appendOptionalAccumulate(mark, leftPlan);
return;
}
if (corrExprs.empty()) {
// No join condition, apply cross product.
if (correlatedExprs.empty()) {
// Plan uncorrelated subquery (think of this as a CTE)
info.predicates = predicates;
auto rightPlan = planQueryGraphCollection(queryGraphCollection, info);
if (leftPlan.hasUpdate()) {
appendAccOptionalCrossProduct(mark, leftPlan, *rightPlan, leftPlan);
Expand All @@ -140,25 +142,35 @@ void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection
}
return;
}
info.corrExprs = corrExprs;
// Plan correlated subquery
info.corrExprsCard = leftPlan.getCardinality();
auto analyzer = UnnestSubqueryAnalyzer(*leftPlan.getSchema(), queryGraphCollection, predicates);
analyzer.analyze();
std::vector<expression_pair> joinConditions;
std::unique_ptr<LogicalPlan> rightPlan;
if (isInternalIDCorrelated(queryGraphCollection, corrExprs)) {
// If all correlated expressions are node IDs. We can trivially unnest by scanning internal
// ID in both outer and inner plan as these are fast in-memory operations. For node
// properties, we only scan in the outer query.
if (analyzer.unnestAsJoin()) {
// Unnest as vanilla join
info.subqueryType = SubqueryType::INTERNAL_ID_CORRELATED;
rightPlan = planQueryGraphCollectionInNewContext(queryGraphCollection, info);
info.corrExprs = analyzer.getCorrelatedInternalIDs();
rightPlan =
planQueryGraphCollectionInNewContext(queryGraphCollection, info);
joinConditions = analyzer.getJoinConditions();
} else {
// Unnest using ExpressionsScan which scans the accumulated table on probe side.
// Unnest as expression scan + distinct & inner join
info.subqueryType = SubqueryType::CORRELATED;
rightPlan = planQueryGraphCollectionInNewContext(queryGraphCollection, info);
appendAccumulate(corrExprs, leftPlan);
info.corrExprs = correlatedExprs;
info.predicates = predicates;
for (auto& expr : correlatedExprs) {
joinConditions.emplace_back(expr, expr);
}
rightPlan =
planQueryGraphCollectionInNewContext(queryGraphCollection, info);
appendAccumulate(correlatedExprs, leftPlan);
}
if (leftPlan.hasUpdate()) {
appendAccHashJoin(corrExprs, JoinType::LEFT, mark, leftPlan, *rightPlan, leftPlan);
appendAccHashJoin(joinConditions, JoinType::LEFT, mark, leftPlan, *rightPlan, leftPlan);
} else {
appendHashJoin(corrExprs, JoinType::LEFT, mark, leftPlan, *rightPlan, leftPlan);
appendHashJoin(joinConditions, JoinType::LEFT, mark, leftPlan, *rightPlan, leftPlan);
}
}

Expand All @@ -176,7 +188,7 @@ void Planner::planRegularMatch(const QueryGraphCollection& queryGraphCollection,
}
auto correlatedExprs =
getCorrelatedExprs(queryGraphCollection, predicatesToPushDown, leftPlan.getSchema());
auto joinNodeIDs = ExpressionUtil::getExpressionsWithDataType(correlatedExpressions,
auto joinNodeIDs = ExpressionUtil::getExpressionsWithDataType(correlatedExprs,
LogicalTypeID::INTERNAL_ID);
auto info = QueryGraphPlanningInfo();
info.predicates = predicatesToPushDown;
Expand Down Expand Up @@ -215,6 +227,7 @@ void Planner::planSubquery(const std::shared_ptr<Expression>& expression, Logica
std::unique_ptr<LogicalPlan> innerPlan;
auto info = QueryGraphPlanningInfo();
if (correlatedExprs.empty()) {
// Plan uncorrelated subquery
info.subqueryType = SubqueryType::NONE;
info.predicates = predicates;
innerPlan =
Expand All @@ -234,45 +247,48 @@ void Planner::planSubquery(const std::shared_ptr<Expression>& expression, Logica
KU_UNREACHABLE;
}
appendCrossProduct(outerPlan, *innerPlan, outerPlan);
return;
}
// Plan correlated subquery
info.corrExprsCard = outerPlan.getCardinality();
auto analyzer = UnnestSubqueryAnalyzer(*outerPlan.getSchema(), *subquery->getQueryGraphCollection(), predicates);
analyzer.analyze();
std::vector<expression_pair> joinConditions;
if (analyzer.unnestAsJoin()) {
// Unnest as vanilla join
info.subqueryType = SubqueryType::INTERNAL_ID_CORRELATED;
info.corrExprs = analyzer.getCorrelatedInternalIDs();
innerPlan =
planQueryGraphCollectionInNewContext(*subquery->getQueryGraphCollection(), info);
joinConditions = analyzer.getJoinConditions();
} else {
info.corrExprsCard = outerPlan.getCardinality();
auto analyzer = UnnestSubqueryAnalyzer(*outerPlan.getSchema(), *subquery->getQueryGraphCollection(), predicates);
analyzer.analyze();
std::vector<expression_pair> joinConditions;
if (analyzer.unnestAsInnerJoin()) {
info.subqueryType = SubqueryType::INTERNAL_ID_CORRELATED;
info.corrExprs = analyzer.getCorrelatedInternalIDs();
innerPlan =
planQueryGraphCollectionInNewContext(*subquery->getQueryGraphCollection(), info);
joinConditions = analyzer.getJoinConditions();
} else { // Unnest
info.subqueryType = SubqueryType::CORRELATED;
info.corrExprs = correlatedExprs;
info.predicates = predicates;
for (auto& expr : correlatedExprs) {
joinConditions.emplace_back(expr, expr);
}
innerPlan =
planQueryGraphCollectionInNewContext(*subquery->getQueryGraphCollection(), info);
appendAccumulate(correlatedExprs, outerPlan);
// Unnest as expression scan + distinct & inner join
info.subqueryType = SubqueryType::CORRELATED;
info.corrExprs = correlatedExprs;
info.predicates = predicates;
for (auto& expr : correlatedExprs) {
joinConditions.emplace_back(expr, expr);
}
switch (subquery->getSubqueryType()) {
case common::SubqueryType::EXISTS: {
appendMarkJoin(joinConditions, expression, outerPlan, *innerPlan, outerPlan);
} break;
case common::SubqueryType::COUNT: {
expression_vector hashKeys;
for (auto& joinCondition : joinConditions) {
hashKeys.push_back(joinCondition.second);
}
appendAggregate(hashKeys, expression_vector{subquery->getProjectionExpr()},
*innerPlan);
appendHashJoin(joinConditions, common::JoinType::COUNT, nullptr, outerPlan, *innerPlan,
outerPlan);
} break;
default:
KU_UNREACHABLE;
innerPlan =
planQueryGraphCollectionInNewContext(*subquery->getQueryGraphCollection(), info);
appendAccumulate(correlatedExprs, outerPlan);
}
switch (subquery->getSubqueryType()) {
case common::SubqueryType::EXISTS: {
appendMarkJoin(joinConditions, expression, outerPlan, *innerPlan, outerPlan);
} break;
case common::SubqueryType::COUNT: {
expression_vector hashKeys;
for (auto& joinCondition : joinConditions) {
hashKeys.push_back(joinCondition.second);
}
appendAggregate(hashKeys, expression_vector{subquery->getProjectionExpr()},
*innerPlan);
appendHashJoin(joinConditions, common::JoinType::COUNT, nullptr, outerPlan, *innerPlan,
outerPlan);
} break;
default:
KU_UNREACHABLE;
}
}

Expand Down

0 comments on commit 01f78b3

Please sign in to comment.