Skip to content

Commit

Permalink
Improve subquery planning (#4651)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU authored Dec 22, 2024
1 parent 9dd8083 commit c93bbd9
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 97 deletions.
8 changes: 5 additions & 3 deletions src/include/common/data_chunk/sel_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class SelectionVector {
enum class State {
DYNAMIC,
STATIC,
STATIC_FILTERED,
};

public:
Expand All @@ -50,9 +49,12 @@ class SelectionVector {
}
void setRange(sel_t startPos, sel_t size) {
KU_ASSERT(startPos + size <= capacity);
selectedPositions = const_cast<sel_t*>(INCREMENTAL_SELECTED_POS.data()) + startPos;
selectedPositions = selectedPositionsBuffer.get();
for (auto i = 0u; i < size; ++i) {
selectedPositions[i] = startPos + i;
}
selectedSize = size;
state = State::STATIC_FILTERED;
state = State::DYNAMIC;
}

// Set to filtered is not very accurate. It sets selectedPositions to a mutable array.
Expand Down
2 changes: 1 addition & 1 deletion src/include/common/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class KUZU_API StringUtils {
}
static std::string_view rtrim(std::string_view input) {
auto end = input.size();
while (end > 0 && isspace(input[end - 1])) {
while (end > 0 && isSpace(input[end - 1])) {
end--;
}
return input.substr(0, end);
Expand Down
4 changes: 4 additions & 0 deletions src/include/planner/join_order/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ class CostModel {
static uint64_t computeExtendCost(const LogicalPlan& childPlan);
static uint64_t computeRecursiveExtendCost(uint8_t upperBound, double extensionRate,
const LogicalPlan& childPlan);
static uint64_t computeHashJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeIntersectCost(const LogicalPlan& probePlan,
Expand Down
20 changes: 13 additions & 7 deletions src/include/planner/planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ namespace planner {

struct LogicalInsertInfo;

enum class SubqueryType : uint8_t {
enum class SubqueryPlanningType : uint8_t {
NONE = 0,
INTERNAL_ID_CORRELATED = 1,
UNNEST_CORRELATED = 1,
CORRELATED = 2,
};

struct QueryGraphPlanningInfo {
// Predicate info.
binder::expression_vector predicates;
// Subquery info.
SubqueryType subqueryType = SubqueryType::NONE;
SubqueryPlanningType subqueryType = SubqueryPlanningType::NONE;
binder::expression_vector corrExprs;
cardinality_t corrExprsCard = 0;
// Join hint info.
Expand Down Expand Up @@ -270,12 +270,18 @@ class Planner {
void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan);
void appendAccHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan);
void appendHashJoin(const std::vector<binder::expression_pair>& joinConditions,
common::JoinType joinType, std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendAccHashJoin(const std::vector<binder::expression_pair>& joinConditions,
common::JoinType joinType, std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
const std::shared_ptr<binder::Expression>& mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan);
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendMarkJoin(const std::vector<binder::expression_pair>& joinConditions,
const std::shared_ptr<binder::Expression>& mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendIntersect(const std::shared_ptr<binder::Expression>& intersectNodeID,
binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan,
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
Expand Down
22 changes: 22 additions & 0 deletions src/planner/join_order/cost_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@ uint64_t CostModel::computeRecursiveExtendCost(uint8_t upperBound, double extens
upperBound;
}

binder::expression_vector getJoinNodeIDs(
const std::vector<binder::expression_pair>& joinConditions) {
binder::expression_vector joinNodeIDs;
for (auto& [left, _] : joinConditions) {
if (left->expressionType == ExpressionType::PROPERTY &&
left->getDataType().getLogicalTypeID() == LogicalTypeID::INTERNAL_ID) {
joinNodeIDs.push_back(left);
}
}
return joinNodeIDs;
}

uint64_t CostModel::computeHashJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build) {
return computeHashJoinCost(getJoinNodeIDs(joinConditions), probe, build);
}

uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build) {
uint64_t cost = 0ul;
Expand All @@ -29,6 +46,11 @@ uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNod
return cost;
}

uint64_t CostModel::computeMarkJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build) {
return computeMarkJoinCost(getJoinNodeIDs(joinConditions), probe, build);
}

uint64_t CostModel::computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build) {
return computeHashJoinCost(joinNodeIDs, probe, build);
Expand Down
29 changes: 21 additions & 8 deletions src/planner/plan/append_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ void Planner::appendHashJoin(const expression_vector& joinNodeIDs, JoinType join
for (auto& joinNodeID : joinNodeIDs) {
joinConditions.emplace_back(joinNodeID, joinNodeID);
}
appendHashJoin(joinConditions, joinType, mark, probePlan, buildPlan, resultPlan);
}

void Planner::appendHashJoin(const std::vector<expression_pair>& joinConditions, JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
auto hashJoin = make_shared<LogicalHashJoin>(joinConditions, joinType, mark,
probePlan.getLastOperator(), buildPlan.getLastOperator());
// Apply flattening to probe side
Expand All @@ -38,26 +44,33 @@ void Planner::appendHashJoin(const expression_vector& joinNodeIDs, JoinType join
// Update cost
hashJoin->setCardinality(cardinalityEstimator.estimateHashJoin(joinConditions,
probePlan.getLastOperatorRef(), buildPlan.getLastOperatorRef()));
resultPlan.setCost(CostModel::computeHashJoinCost(joinNodeIDs, probePlan, buildPlan));
resultPlan.setCost(CostModel::computeHashJoinCost(joinConditions, probePlan, buildPlan));
resultPlan.setLastOperator(std::move(hashJoin));
}

void Planner::appendAccHashJoin(const expression_vector& joinNodeIDs, JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
void Planner::appendAccHashJoin(const std::vector<binder::expression_pair>& joinConditions,
JoinType joinType, std::shared_ptr<Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan) {
KU_ASSERT(probePlan.hasUpdate());
tryAppendAccumulate(probePlan);
appendHashJoin(joinNodeIDs, joinType, mark, probePlan, buildPlan, resultPlan);
appendHashJoin(joinConditions, joinType, mark, probePlan, buildPlan, resultPlan);
auto& sipInfo = probePlan.getLastOperator()->cast<LogicalHashJoin>().getSIPInfoUnsafe();
sipInfo.direction = SIPDirection::PROBE_TO_BUILD;
}

void Planner::appendMarkJoin(const expression_vector& joinNodeIDs,
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
std::vector<join_condition_t> joinConditions;
for (auto& joinNodeID : joinNodeIDs) {
joinConditions.emplace_back(joinNodeID, joinNodeID);
}
appendMarkJoin(joinConditions, mark, probePlan, buildPlan, resultPlan);
}

void Planner::appendMarkJoin(const std::vector<expression_pair>& joinConditions,
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
auto hashJoin = make_shared<LogicalHashJoin>(joinConditions, JoinType::MARK, mark,
probePlan.getLastOperator(), buildPlan.getLastOperator());
// Apply flattening to probe side
Expand All @@ -69,8 +82,8 @@ void Planner::appendMarkJoin(const expression_vector& joinNodeIDs,
hashJoin->computeFactorizedSchema();
// update cost. Mark join does not change cardinality.
hashJoin->setCardinality(probePlan.getCardinality());
probePlan.setCost(CostModel::computeMarkJoinCost(joinNodeIDs, probePlan, buildPlan));
probePlan.setLastOperator(std::move(hashJoin));
resultPlan.setCost(CostModel::computeMarkJoinCost(joinConditions, probePlan, buildPlan));
resultPlan.setLastOperator(std::move(hashJoin));
}

void Planner::appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
Expand Down
19 changes: 10 additions & 9 deletions src/planner/plan/plan_join_order.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::enumerateQueryGraphCollection
auto& corrExprs = info.corrExprs;
auto corrExprsSet = binder::expression_set{corrExprs.begin(), corrExprs.end()};
int32_t queryGraphIdxToPlanExpressionsScan = -1;
if (info.subqueryType == SubqueryType::CORRELATED) {
if (info.subqueryType == SubqueryPlanningType::CORRELATED) {
// Pick a query graph to plan ExpressionsScan. If -1 is returned, we fall back to cross
// product.
queryGraphIdxToPlanExpressionsScan =
Expand Down Expand Up @@ -80,17 +80,17 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::enumerateQueryGraphCollection
auto newInfo = info;
newInfo.predicates = predicatesToEvaluate;
switch (info.subqueryType) {
case SubqueryType::NONE:
case SubqueryType::INTERNAL_ID_CORRELATED: {
case SubqueryPlanningType::NONE:
case SubqueryPlanningType::UNNEST_CORRELATED: {
plans = enumerateQueryGraph(*queryGraph, newInfo);
} break;
case SubqueryType::CORRELATED: {
case SubqueryPlanningType::CORRELATED: {
if (i == (uint32_t)queryGraphIdxToPlanExpressionsScan) {
// Plan ExpressionsScan with current query graph.
plans = enumerateQueryGraph(*queryGraph, newInfo);
} else {
// Plan current query graph as an isolated query graph.
newInfo.subqueryType = SubqueryType::NONE;
newInfo.subqueryType = SubqueryPlanningType::NONE;
plans = enumerateQueryGraph(*queryGraph, newInfo);
}
} break;
Expand All @@ -101,7 +101,8 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::enumerateQueryGraphCollection
}
// Fail to plan ExpressionsScan with any query graph. Plan it independently and fall back to
// cross product.
if (info.subqueryType == SubqueryType::CORRELATED && queryGraphIdxToPlanExpressionsScan == -1) {
if (info.subqueryType == SubqueryPlanningType::CORRELATED &&
queryGraphIdxToPlanExpressionsScan == -1) {
auto plan = std::make_unique<LogicalPlan>();
appendExpressionsScan(corrExprs, *plan);
appendDistinct(corrExprs, *plan);
Expand Down Expand Up @@ -186,12 +187,12 @@ void Planner::planBaseTableScans(const QueryGraphPlanningInfo& info) {
auto& corrExprs = info.corrExprs;
auto corrExprsSet = expression_set{corrExprs.begin(), corrExprs.end()};
switch (info.subqueryType) {
case SubqueryType::NONE: {
case SubqueryPlanningType::NONE: {
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
planNodeScan(nodePos);
}
} break;
case SubqueryType::INTERNAL_ID_CORRELATED: {
case SubqueryPlanningType::UNNEST_CORRELATED: {
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
auto queryNode = queryGraph->getQueryNode(nodePos);
if (corrExprsSet.contains(queryNode->getInternalID())) {
Expand All @@ -205,7 +206,7 @@ void Planner::planBaseTableScans(const QueryGraphPlanningInfo& info) {
}
}
} break;
case SubqueryType::CORRELATED: {
case SubqueryPlanningType::CORRELATED: {
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
auto queryNode = queryGraph->getQueryNode(nodePos);
if (corrExprsSet.contains(queryNode->getInternalID())) {
Expand Down
Loading

0 comments on commit c93bbd9

Please sign in to comment.