Add scheduler #395

Merged
merged 9 commits into from
Dec 29, 2023
10 changes: 10 additions & 0 deletions src/common/blocking_queue.cppm
@@ -102,6 +102,16 @@ public:
full_cv_.notify_one();
}

// Non-blocking bulk dequeue: move every queued element into output_array.
bool TryDequeueBulk(Vector<T> &output_array) {
UniqueLock<Mutex> lock(queue_mutex_);
if (queue_.empty()) {
return false;
}
output_array.insert(output_array.end(), queue_.begin(), queue_.end());
queue_.clear();
full_cv_.notify_one();
return true;
}

[[nodiscard]] SizeT Size() const {
LockGuard<Mutex> lock(queue_mutex_);
return queue_.size();
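A note on the new `TryDequeueBulk` (editor's illustration, not part of the diff): the point of a bulk dequeue is that a consumer drains everything queued under a single lock acquisition and then processes the batch outside the lock. A standalone sketch of the same pattern with plain `std::` types:

```cpp
#include <deque>
#include <iostream>
#include <mutex>
#include <vector>

struct Task {
    int id;
};

std::mutex queue_mutex;
std::deque<Task> task_queue;

// Same shape as BlockingQueue::TryDequeueBulk above: move everything out, or report empty.
bool TryDequeueBulk(std::vector<Task> &output) {
    std::lock_guard<std::mutex> lock(queue_mutex);
    if (task_queue.empty()) {
        return false;
    }
    output.insert(output.end(), task_queue.begin(), task_queue.end());
    task_queue.clear();
    return true;
}

int main() {
    task_queue = {{1}, {2}, {3}};
    std::vector<Task> batch;
    if (TryDequeueBulk(batch)) {
        for (const auto &task : batch) {
            std::cout << "run task " << task.id << "\n";  // work happens outside the lock
        }
    }
    return 0;
}
```

Draining in bulk keeps the critical section to one acquisition per wake-up, which matters when many small fragment tasks arrive at once.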
3 changes: 3 additions & 0 deletions src/common/stl.cppm
@@ -94,6 +94,9 @@ export {
template <typename S, typename T, typename H = std::hash<S>>
using HashMap = std::unordered_map<S, T, H>;

template <typename S, typename T, typename H = std::hash<S>>
using MultiHashMap = std::unordered_multimap<S, T, H>;

template <typename S>
using HashSet = std::unordered_set<S>;

2 changes: 2 additions & 0 deletions src/executor/fragment/plan_fragment.cppm
@@ -57,6 +57,8 @@ public:

[[nodiscard]] inline PhysicalSink *GetSinkNode() const { return sink_.get(); }

[[nodiscard]] inline PlanFragment *GetParent() const { return parent_; }

inline void AddChild(UniquePtr<PlanFragment> child_fragment) {
child_fragment->parent_ = this;
children_.emplace_back(Move(child_fragment));
9 changes: 3 additions & 6 deletions src/executor/operator/physical_create_index_finish.cpp
@@ -45,13 +45,10 @@ bool PhysicalCreateIndexFinish::Execute(QueryContext *query_context, OperatorSta
auto *txn = query_context->GetTxn();
auto *create_index_finish_op_state = static_cast<CreateIndexFinishOperatorState *>(operator_state);

- if (create_index_finish_op_state->input_complete_) {
- txn->AddWalCmd(MakeShared<WalCmdCreateIndex>(*db_name_, *table_name_, index_def_));
+ txn->AddWalCmd(MakeShared<WalCmdCreateIndex>(*db_name_, *table_name_, index_def_));

- operator_state->SetComplete();
- return true;
- }
- return false;
+ operator_state->SetComplete();
+ return true;
}

} // namespace infinity
10 changes: 10 additions & 0 deletions src/executor/operator/physical_sink.cpp
@@ -338,6 +338,16 @@ void PhysicalSink::FillSinkStateFromLastOperatorState(MessageSinkState *message_
message_sink_state->message_ = Move(insert_output_state->result_msg_);
break;
}
case PhysicalOperatorType::kCreateIndexPrepare: {
auto *create_index_prepare_output_state = static_cast<CreateIndexPrepareOperatorState *>(task_operator_state);
message_sink_state->message_ = Move(create_index_prepare_output_state->result_msg_);
break;
}
case PhysicalOperatorType::kCreateIndexDo: {
auto *create_index_do_output_state = static_cast<CreateIndexDoOperatorState *>(task_operator_state);
message_sink_state->message_ = Move(create_index_do_output_state->result_msg_);
break;
}
default: {
Error<NotImplementException>(Format("{} isn't supported here.", PhysicalOperatorToString(task_operator_state->operator_type_)));
break;
9 changes: 0 additions & 9 deletions src/executor/operator/physical_source.cpp
@@ -57,13 +57,4 @@ bool PhysicalSource::Execute(QueryContext *, SourceState *source_state) {
return true;
}

- bool PhysicalSource::ReadyToExec(SourceState *source_state) {
- bool result = true;
- if (source_state->state_type_ == SourceStateType::kQueue) {
- QueueSourceState *queue_source_state = static_cast<QueueSourceState *>(source_state);
- result = queue_source_state->source_queue_.Size() > 0;
- }
- return result;
- }

} // namespace infinity
2 changes: 0 additions & 2 deletions src/executor/operator/physical_source.cppm
@@ -53,8 +53,6 @@ public:

bool Execute(QueryContext *query_context, SourceState *source_state);

- bool ReadyToExec(SourceState *source_state);

inline SharedPtr<Vector<String>> GetOutputNames() const final { return output_names_; }

inline SharedPtr<Vector<SharedPtr<DataType>>> GetOutputTypes() const final { return output_types_; }
10 changes: 0 additions & 10 deletions src/executor/operator_state.cpp
@@ -92,16 +92,6 @@ bool QueueSourceState::GetData() {
fusion_op_state->input_complete_ = completed;
break;
}
- case PhysicalOperatorType::kCreateIndexDo: {
- auto *create_index_do_op_state = static_cast<CreateIndexDoOperatorState *>(next_op_state);
- create_index_do_op_state->input_complete_ = completed;
- break;
- }
- case PhysicalOperatorType::kCreateIndexFinish: {
- auto *create_index_finish_op_state = static_cast<CreateIndexFinishOperatorState *>(next_op_state);
- create_index_finish_op_state->input_complete_ = completed;
- break;
- }
case PhysicalOperatorType::kMergeLimit: {
auto *fragment_data = static_cast<FragmentData *>(fragment_data_base.get());
MergeLimitOperatorState *limit_op_state = (MergeLimitOperatorState *)next_op_state;
3 changes: 1 addition & 2 deletions src/executor/operator_state.cppm
@@ -245,14 +245,13 @@ export struct CreateIndexPrepareOperatorState : public OperatorState {
export struct CreateIndexDoOperatorState : public OperatorState {
inline explicit CreateIndexDoOperatorState() : OperatorState(PhysicalOperatorType::kCreateIndexDo) {}

- bool input_complete_ = false;
UniquePtr<String> result_msg_{};
CreateIndexSharedData *create_index_shared_data_;
};

export struct CreateIndexFinishOperatorState : public OperatorState {
inline explicit CreateIndexFinishOperatorState() : OperatorState(PhysicalOperatorType::kCreateIndexFinish) {}

- bool input_complete_ = false;
UniquePtr<String> error_message_{};
};

2 changes: 1 addition & 1 deletion src/executor/physical_operator_type.cppm
@@ -70,7 +70,6 @@ export enum class PhysicalOperatorType : i8 {
kInsert,
kImport,
kExport,
- kCreateIndexDo,

// DDL
kAlter,
@@ -86,6 +85,7 @@ export enum class PhysicalOperatorType : i8 {
kDropView,

kCreateIndexPrepare,
kCreateIndexDo,
kCreateIndexFinish,

// misc
5 changes: 2 additions & 3 deletions src/main/query_context.cpp
@@ -141,12 +141,11 @@ QueryResult QueryContext::QueryStatement(const BaseStatement *statement) {
StopProfile(QueryPhase::kPipelineBuild);

StartProfile(QueryPhase::kTaskBuild);
- Vector<FragmentTask *> tasks;
- FragmentContext::BuildTask(this, nullptr, plan_fragment.get(), tasks);
+ FragmentContext::BuildTask(this, nullptr, plan_fragment.get());
StopProfile(QueryPhase::kTaskBuild);

StartProfile(QueryPhase::kExecution);
- scheduler_->Schedule(this, tasks, plan_fragment.get());
+ scheduler_->Schedule(plan_fragment.get());
query_result.result_table_ = plan_fragment->GetResult();
query_result.root_operator_type_ = logical_plan->operator_type();
StopProfile(QueryPhase::kExecution);
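Context for the query_context.cpp change above: `BuildTask` no longer fills a flat `Vector<FragmentTask *>`, and `Schedule` now receives the root `PlanFragment` rather than a task list. The fragment_context.cpp diff below suggests leaf fragments are submitted first and a parent runs only after its children report completion. A standalone sketch of that submission order, using a toy fragment type instead of the project's classes (assumption for illustration only):

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Toy stand-in for PlanFragment; only the tree shape matters here.
struct Fragment {
    int id = 0;
    std::vector<std::unique_ptr<Fragment>> children;
};

// Submit leaves first; interior fragments are expected to be submitted later,
// when their last child finishes (see TryFinishFragment below).
void SubmitLeaves(const Fragment *fragment) {
    if (fragment->children.empty()) {
        std::cout << "submit fragment " << fragment->id << "\n";
        return;
    }
    for (const auto &child : fragment->children) {
        SubmitLeaves(child.get());
    }
}

int main() {
    Fragment root{};
    auto left = std::make_unique<Fragment>();
    left->id = 1;
    auto right = std::make_unique<Fragment>();
    right->id = 2;
    root.children.push_back(std::move(left));
    root.children.push_back(std::move(right));
    SubmitLeaves(&root);  // submits fragments 1 and 2; the root waits for both
    return 0;
}
```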
70 changes: 45 additions & 25 deletions src/scheduler/fragment_context.cpp
@@ -47,7 +47,7 @@ import physical_merge_knn;
import merge_knn_data;
import create_index_data;
import logger;

import task_scheduler;
import plan_fragment;

module fragment_context;
@@ -326,10 +326,7 @@ void CollectTasks(Vector<SharedPtr<String>> &result, PlanFragment *fragment_ptr)
}
}

- void FragmentContext::BuildTask(QueryContext *query_context,
- FragmentContext *parent_context,
- PlanFragment *fragment_ptr,
- Vector<FragmentTask *> &task_array) {
+ void FragmentContext::BuildTask(QueryContext *query_context, FragmentContext *parent_context, PlanFragment *fragment_ptr) {
Vector<PhysicalOperator *> &fragment_operators = fragment_ptr->GetOperators();
i64 operator_count = fragment_operators.size();
if (operator_count < 1) {
@@ -419,7 +416,7 @@ void FragmentContext::BuildTask(QueryContext *query_context,
if (fragment_ptr->HasChild()) {
// current fragment have children
for (const auto &child_fragment : fragment_ptr->Children()) {
- FragmentContext::BuildTask(query_context, fragment_context.get(), child_fragment.get(), task_array);
+ FragmentContext::BuildTask(query_context, fragment_context.get(), child_fragment.get());
}
}
switch (fragment_operators[0]->operator_type()) {
@@ -430,34 +427,48 @@
if (explain_op->explain_type() == ExplainType::kPipeline) {
CollectTasks(result, fragment_ptr->Children()[0].get());
explain_op->SetExplainTaskText(MakeShared<Vector<SharedPtr<String>>>(result));
- task_array.clear();
break;
}
}
default:
break;
}

- for (const auto &task : tasks) {
- task_array.emplace_back(task.get());
- }

fragment_ptr->SetContext(Move(fragment_context));
}

FragmentContext::FragmentContext(PlanFragment *fragment_ptr, QueryContext *query_context)
- : fragment_ptr_(fragment_ptr), fragment_type_(fragment_ptr->GetFragmentType()), query_context_(query_context){};
+ : fragment_ptr_(fragment_ptr), query_context_(query_context), fragment_type_(fragment_ptr->GetFragmentType()),
+ fragment_status_(FragmentStatus::kNotStart), unfinished_child_n_(fragment_ptr->Children().size()) {}

- void FragmentContext::FinishTask() {
- u64 unfinished_task = task_n_.fetch_sub(1);
- auto sink_op = GetSinkOperator();
+ void FragmentContext::TryFinishFragment() {
+ if (!TryFinishFragmentInner()) {
+ LOG_TRACE(Format("{} tasks in fragment are not completed: {} are not completed", unfinished_task_n_.load(), fragment_ptr_->FragmentID()));
+ return;
+ }
+ LOG_TRACE(Format("All tasks in fragment: {} are completed", fragment_ptr_->FragmentID()));
+ fragment_status_ = FragmentStatus::kFinish;

- if (unfinished_task == 1 && sink_op->sink_type() == SinkType::kResult) {
- LOG_TRACE(Format("All tasks in fragment: {} are completed", fragment_ptr_->FragmentID()));
+ auto *sink_op = GetSinkOperator();
+ if (sink_op->sink_type() == SinkType::kResult) {
Complete();
- } else {
- LOG_TRACE(Format("Not all tasks in fragment: {} are completed", fragment_ptr_->FragmentID()));
+ return;
}

// Try to schedule parent fragment
auto *parent_plan_fragment = fragment_ptr_->GetParent();
if (parent_plan_fragment == nullptr) {
return;
}
auto *parent_fragment_ctx = parent_plan_fragment->GetContext();

if (!parent_fragment_ctx->TryStartFragment() && parent_fragment_ctx->fragment_type_ != FragmentType::kParallelStream) {
return;
}
// All child fragment are finished.
auto *scheduler = query_context_->scheduler();
scheduler->ScheduleFragment(parent_plan_fragment);
return;
}

Vector<PhysicalOperator *> &FragmentContext::GetOperators() { return fragment_ptr_->GetOperators(); }
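The `TryFinishFragment` hunk above relies on two helpers that are not shown here, `TryFinishFragmentInner` and `TryStartFragment`, together with the new `unfinished_task_n_` and `unfinished_child_n_` members. A minimal sketch of the counter discipline the call sites imply (an assumption for illustration, not the PR's actual implementation):

```cpp
#include <atomic>
#include <cstdint>

// Editor's sketch only; names mirror the call sites in TryFinishFragment above.
struct FragmentProgress {
    std::atomic<uint64_t> unfinished_task_n{0};   // tasks still running in this fragment
    std::atomic<uint64_t> unfinished_child_n{0};  // child fragments not yet finished

    // A task calls this when it completes; returns true only for the last task,
    // i.e. the whole fragment has just finished.
    bool TryFinishFragmentInner() {
        return unfinished_task_n.fetch_sub(1) == 1;
    }

    // A finishing child calls this on its parent; returns true once every child
    // has finished, so the parent fragment can be handed to the scheduler.
    bool TryStartFragment() {
        return unfinished_child_n.fetch_sub(1) == 1;
    }
};
```

Under that reading, the `kParallelStream` exception in the code above lets a streaming parent be scheduled before all of its children have finished, since it can consume partial results.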
@@ -561,8 +572,7 @@ void FragmentContext::MakeSourceState(i64 parallel_count) {
case PhysicalOperatorType::kMergeTop:
case PhysicalOperatorType::kMergeSort:
case PhysicalOperatorType::kMergeKnn:
- case PhysicalOperatorType::kFusion:
- case PhysicalOperatorType::kCreateIndexFinish: {
+ case PhysicalOperatorType::kFusion: {
if (fragment_type_ != FragmentType::kSerialMaterialize) {
Error<SchedulerException>(
Format("{} should be serial materialized fragment", PhysicalOperatorToString(first_operator->operator_type())));
Expand All @@ -581,7 +591,7 @@ void FragmentContext::MakeSourceState(i64 parallel_count) {
Format("{} should in parallel materialized fragment", PhysicalOperatorToString(first_operator->operator_type())));
}
for (auto &task : tasks_) {
- task->source_state_ = MakeUnique<QueueSourceState>();
+ task->source_state_ = MakeUnique<EmptySourceState>();
}
break;
}
@@ -634,6 +644,7 @@ void FragmentContext::MakeSourceState(i64 parallel_count) {
case PhysicalOperatorType::kCreateTable:
case PhysicalOperatorType::kCreateIndex:
case PhysicalOperatorType::kCreateIndexPrepare:
case PhysicalOperatorType::kCreateIndexFinish:
case PhysicalOperatorType::kCreateCollection:
case PhysicalOperatorType::kCreateDatabase:
case PhysicalOperatorType::kCreateView:
@@ -754,9 +765,7 @@ void FragmentContext::MakeSinkState(i64 parallel_count) {
break;
}
case PhysicalOperatorType::kSort:
- case PhysicalOperatorType::kKnnScan:
- case PhysicalOperatorType::kCreateIndexPrepare:
- case PhysicalOperatorType::kCreateIndexDo: {
+ case PhysicalOperatorType::kKnnScan: {
if (fragment_type_ != FragmentType::kParallelMaterialize && fragment_type_ != FragmentType::kSerialMaterialize) {
Error<SchedulerException>(
Format("{} should in parallel/serial materialized fragment", PhysicalOperatorToString(first_operator->operator_type())));
@@ -835,6 +844,7 @@ void FragmentContext::MakeSinkState(i64 parallel_count) {
}
break;
}
case PhysicalOperatorType::kCreateIndexPrepare:
case PhysicalOperatorType::kInsert:
case PhysicalOperatorType::kImport:
case PhysicalOperatorType::kExport: {
@@ -850,6 +860,16 @@ void FragmentContext::MakeSinkState(i64 parallel_count) {
tasks_[0]->sink_state_ = MakeUnique<MessageSinkState>();
break;
}
case PhysicalOperatorType::kCreateIndexDo: {
if (fragment_type_ != FragmentType::kParallelMaterialize) {
Error<SchedulerException>(
Format("{} should in parallel materialized fragment", PhysicalOperatorToString(last_operator->operator_type())));
}
for (auto &task : tasks_) {
task->sink_state_ = MakeUnique<MessageSinkState>();
}
break;
}
case PhysicalOperatorType::kCommand:
case PhysicalOperatorType::kCreateTable:
case PhysicalOperatorType::kCreateIndex: