Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve subquery planning #4651

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/include/common/data_chunk/sel_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class SelectionVector {
enum class State {
DYNAMIC,
STATIC,
STATIC_FILTERED,
};

public:
Expand All @@ -50,9 +49,12 @@ class SelectionVector {
}
void setRange(sel_t startPos, sel_t size) {
KU_ASSERT(startPos + size <= capacity);
selectedPositions = const_cast<sel_t*>(INCREMENTAL_SELECTED_POS.data()) + startPos;
selectedPositions = selectedPositionsBuffer.get();
for (auto i = 0u; i < size; ++i) {
selectedPositions[i] = startPos + i;
}
selectedSize = size;
state = State::STATIC_FILTERED;
state = State::DYNAMIC;
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
}

// Set to filtered is not very accurate. It sets selectedPositions to a mutable array.
Expand Down
2 changes: 1 addition & 1 deletion src/include/common/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class KUZU_API StringUtils {
}
static std::string_view rtrim(std::string_view input) {
auto end = input.size();
while (end > 0 && isspace(input[end - 1])) {
while (end > 0 && isSpace(input[end - 1])) {
end--;
}
return input.substr(0, end);
Expand Down
4 changes: 4 additions & 0 deletions src/include/planner/join_order/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ class CostModel {
static uint64_t computeExtendCost(const LogicalPlan& childPlan);
static uint64_t computeRecursiveExtendCost(uint8_t upperBound, double extensionRate,
const LogicalPlan& childPlan);
static uint64_t computeHashJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeIntersectCost(const LogicalPlan& probePlan,
Expand Down
20 changes: 13 additions & 7 deletions src/include/planner/planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ namespace planner {

struct LogicalInsertInfo;

enum class SubqueryType : uint8_t {
enum class SubqueryPlanningType : uint8_t {
NONE = 0,
INTERNAL_ID_CORRELATED = 1,
UNNEST_CORRELATED = 1,
CORRELATED = 2,
};

struct QueryGraphPlanningInfo {
// Predicate info.
binder::expression_vector predicates;
// Subquery info.
SubqueryType subqueryType = SubqueryType::NONE;
SubqueryPlanningType subqueryType = SubqueryPlanningType::NONE;
binder::expression_vector corrExprs;
cardinality_t corrExprsCard = 0;
// Join hint info.
Expand Down Expand Up @@ -270,12 +270,18 @@ class Planner {
void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan);
void appendAccHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan);
void appendHashJoin(const std::vector<binder::expression_pair>& joinConditions,
common::JoinType joinType, std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendAccHashJoin(const std::vector<binder::expression_pair>& joinConditions,
common::JoinType joinType, std::shared_ptr<binder::Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
const std::shared_ptr<binder::Expression>& mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan);
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendMarkJoin(const std::vector<binder::expression_pair>& joinConditions,
const std::shared_ptr<binder::Expression>& mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendIntersect(const std::shared_ptr<binder::Expression>& intersectNodeID,
binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan,
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
Expand Down
22 changes: 22 additions & 0 deletions src/planner/join_order/cost_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@ uint64_t CostModel::computeRecursiveExtendCost(uint8_t upperBound, double extens
upperBound;
}

binder::expression_vector getJoinNodeIDs(
const std::vector<binder::expression_pair>& joinConditions) {
binder::expression_vector joinNodeIDs;
for (auto& [left, _] : joinConditions) {
if (left->expressionType == ExpressionType::PROPERTY &&
left->getDataType().getLogicalTypeID() == LogicalTypeID::INTERNAL_ID) {
joinNodeIDs.push_back(left);
}
}
return joinNodeIDs;
}

uint64_t CostModel::computeHashJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build) {
return computeHashJoinCost(getJoinNodeIDs(joinConditions), probe, build);
}

uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build) {
uint64_t cost = 0ul;
Expand All @@ -29,6 +46,11 @@ uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNod
return cost;
}

uint64_t CostModel::computeMarkJoinCost(const std::vector<binder::expression_pair>& joinConditions,
const LogicalPlan& probe, const LogicalPlan& build) {
return computeMarkJoinCost(getJoinNodeIDs(joinConditions), probe, build);
}

uint64_t CostModel::computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build) {
return computeHashJoinCost(joinNodeIDs, probe, build);
Expand Down
29 changes: 21 additions & 8 deletions src/planner/plan/append_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ void Planner::appendHashJoin(const expression_vector& joinNodeIDs, JoinType join
for (auto& joinNodeID : joinNodeIDs) {
joinConditions.emplace_back(joinNodeID, joinNodeID);
}
appendHashJoin(joinConditions, joinType, mark, probePlan, buildPlan, resultPlan);
}

void Planner::appendHashJoin(const std::vector<expression_pair>& joinConditions, JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
auto hashJoin = make_shared<LogicalHashJoin>(joinConditions, joinType, mark,
probePlan.getLastOperator(), buildPlan.getLastOperator());
// Apply flattening to probe side
Expand All @@ -38,26 +44,33 @@ void Planner::appendHashJoin(const expression_vector& joinNodeIDs, JoinType join
// Update cost
hashJoin->setCardinality(cardinalityEstimator.estimateHashJoin(joinConditions,
probePlan.getLastOperatorRef(), buildPlan.getLastOperatorRef()));
resultPlan.setCost(CostModel::computeHashJoinCost(joinNodeIDs, probePlan, buildPlan));
resultPlan.setCost(CostModel::computeHashJoinCost(joinConditions, probePlan, buildPlan));
resultPlan.setLastOperator(std::move(hashJoin));
}

void Planner::appendAccHashJoin(const expression_vector& joinNodeIDs, JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
void Planner::appendAccHashJoin(const std::vector<binder::expression_pair>& joinConditions,
JoinType joinType, std::shared_ptr<Expression> mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan, LogicalPlan& resultPlan) {
KU_ASSERT(probePlan.hasUpdate());
tryAppendAccumulate(probePlan);
appendHashJoin(joinNodeIDs, joinType, mark, probePlan, buildPlan, resultPlan);
appendHashJoin(joinConditions, joinType, mark, probePlan, buildPlan, resultPlan);
auto& sipInfo = probePlan.getLastOperator()->cast<LogicalHashJoin>().getSIPInfoUnsafe();
sipInfo.direction = SIPDirection::PROBE_TO_BUILD;
}

void Planner::appendMarkJoin(const expression_vector& joinNodeIDs,
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
std::vector<join_condition_t> joinConditions;
for (auto& joinNodeID : joinNodeIDs) {
joinConditions.emplace_back(joinNodeID, joinNodeID);
}
appendMarkJoin(joinConditions, mark, probePlan, buildPlan, resultPlan);
}

void Planner::appendMarkJoin(const std::vector<expression_pair>& joinConditions,
const std::shared_ptr<Expression>& mark, LogicalPlan& probePlan, LogicalPlan& buildPlan,
LogicalPlan& resultPlan) {
auto hashJoin = make_shared<LogicalHashJoin>(joinConditions, JoinType::MARK, mark,
probePlan.getLastOperator(), buildPlan.getLastOperator());
// Apply flattening to probe side
Expand All @@ -69,8 +82,8 @@ void Planner::appendMarkJoin(const expression_vector& joinNodeIDs,
hashJoin->computeFactorizedSchema();
// update cost. Mark join does not change cardinality.
hashJoin->setCardinality(probePlan.getCardinality());
probePlan.setCost(CostModel::computeMarkJoinCost(joinNodeIDs, probePlan, buildPlan));
probePlan.setLastOperator(std::move(hashJoin));
resultPlan.setCost(CostModel::computeMarkJoinCost(joinConditions, probePlan, buildPlan));
resultPlan.setLastOperator(std::move(hashJoin));
}

void Planner::appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
Expand Down
19 changes: 10 additions & 9 deletions src/planner/plan/plan_join_order.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::enumerateQueryGraphCollection
auto& corrExprs = info.corrExprs;
auto corrExprsSet = binder::expression_set{corrExprs.begin(), corrExprs.end()};
int32_t queryGraphIdxToPlanExpressionsScan = -1;
if (info.subqueryType == SubqueryType::CORRELATED) {
if (info.subqueryType == SubqueryPlanningType::CORRELATED) {
// Pick a query graph to plan ExpressionsScan. If -1 is returned, we fall back to cross
// product.
queryGraphIdxToPlanExpressionsScan =
Expand Down Expand Up @@ -80,17 +80,17 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::enumerateQueryGraphCollection
auto newInfo = info;
newInfo.predicates = predicatesToEvaluate;
switch (info.subqueryType) {
case SubqueryType::NONE:
case SubqueryType::INTERNAL_ID_CORRELATED: {
case SubqueryPlanningType::NONE:
case SubqueryPlanningType::UNNEST_CORRELATED: {
plans = enumerateQueryGraph(*queryGraph, newInfo);
} break;
case SubqueryType::CORRELATED: {
case SubqueryPlanningType::CORRELATED: {
if (i == (uint32_t)queryGraphIdxToPlanExpressionsScan) {
// Plan ExpressionsScan with current query graph.
plans = enumerateQueryGraph(*queryGraph, newInfo);
} else {
// Plan current query graph as an isolated query graph.
newInfo.subqueryType = SubqueryType::NONE;
newInfo.subqueryType = SubqueryPlanningType::NONE;
plans = enumerateQueryGraph(*queryGraph, newInfo);
}
} break;
Expand All @@ -101,7 +101,8 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::enumerateQueryGraphCollection
}
// Fail to plan ExpressionsScan with any query graph. Plan it independently and fall back to
// cross product.
if (info.subqueryType == SubqueryType::CORRELATED && queryGraphIdxToPlanExpressionsScan == -1) {
if (info.subqueryType == SubqueryPlanningType::CORRELATED &&
queryGraphIdxToPlanExpressionsScan == -1) {
auto plan = std::make_unique<LogicalPlan>();
appendExpressionsScan(corrExprs, *plan);
appendDistinct(corrExprs, *plan);
Expand Down Expand Up @@ -186,12 +187,12 @@ void Planner::planBaseTableScans(const QueryGraphPlanningInfo& info) {
auto& corrExprs = info.corrExprs;
auto corrExprsSet = expression_set{corrExprs.begin(), corrExprs.end()};
switch (info.subqueryType) {
case SubqueryType::NONE: {
case SubqueryPlanningType::NONE: {
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
planNodeScan(nodePos);
}
} break;
case SubqueryType::INTERNAL_ID_CORRELATED: {
case SubqueryPlanningType::UNNEST_CORRELATED: {
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
auto queryNode = queryGraph->getQueryNode(nodePos);
if (corrExprsSet.contains(queryNode->getInternalID())) {
Expand All @@ -205,7 +206,7 @@ void Planner::planBaseTableScans(const QueryGraphPlanningInfo& info) {
}
}
} break;
case SubqueryType::CORRELATED: {
case SubqueryPlanningType::CORRELATED: {
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
auto queryNode = queryGraph->getQueryNode(nodePos);
if (corrExprsSet.contains(queryNode->getInternalID())) {
Expand Down
Loading