From ca3dabf7d23cf2173fca830249c4cb9eeb6171bf Mon Sep 17 00:00:00 2001
From: Zhihao Jia
Date: Sat, 5 Oct 2024 11:36:34 -0700
Subject: [PATCH] [AllReduce] make AllReduce tasks concurrent in FlexFlow
 (#1517)

* minor bug fix

* make AllReduce tasks concurrent

* set concurrent=true for remaining operators

---------

Co-authored-by: Gabriele Oliaro
---
 src/ops/fused.cc                      |  6 ++++++
 src/ops/lora_linear.cc                |  2 ++
 src/parallel_ops/allreduce.cc         |  5 +++++
 src/parallel_ops/parallel_identity.cc |  4 ++++
 src/runtime/model.cc                  | 23 +++++++++++++++++++++++
 5 files changed, 40 insertions(+)

diff --git a/src/ops/fused.cc b/src/ops/fused.cc
index 121139beb1..720d678a4a 100644
--- a/src/ops/fused.cc
+++ b/src/ops/fused.cc
@@ -476,6 +476,7 @@ void FusedOp::init(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          outputs[0]->machine_view.hash());
+  launcher.concurrent = true;
   FutureMap fm = runtime->execute_index_space(ctx, launcher);
   fm.wait_all_results();
   switch (domain.get_dim()) {
@@ -570,6 +571,7 @@ void FusedOp::init_inference(FFModel const &ff,
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   FutureMap fm = runtime->execute_index_space(ctx, launcher);
   fm.wait_all_results();
   switch (domain.get_dim()) {
@@ -604,6 +606,7 @@ void FusedOp::forward(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          outputs[0]->machine_view.hash());
+  launcher.concurrent = true;
   int offset = 0;
   for (int i = 0; i < numInputs; i++) {
     assert(inputs[i]->part != LogicalPartition::NO_PART);
@@ -659,6 +662,7 @@ FutureMap FusedOp::inference(FFModel const &ff,
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_future(bc);
   int offset = 0;
   for (int i = 0; i < numInputs; i++) {
@@ -735,6 +739,7 @@ FutureMap FusedOp::peft_bwd(FFModel const &ff,
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_future(bc);
   int offset = 0;
   for (int i = 0; i < numInputs; i++) {
@@ -787,6 +792,7 @@ void FusedOp::backward(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          outputs[0]->machine_view.hash());
+  launcher.concurrent = true;
   int idx = 0;
   for (int i = 0; i < numInputs; i++) {
     launcher.add_region_requirement(RegionRequirement(inputs[i]->part,
diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc
index fde6bc2b28..513147f3b7 100644
--- a/src/ops/lora_linear.cc
+++ b/src/ops/lora_linear.cc
@@ -296,6 +296,7 @@ void LoraLinear::init_inference(
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -795,6 +796,7 @@ FutureMap LoraLinear::peft_bwd(FFModel const &ff,
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_future(bc);
   launcher.add_region_requirement(
       RegionRequirement(batch_inputs[0]->part_grad,
diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc
index 52c4ec2e28..dc43d80133 100644
--- a/src/parallel_ops/allreduce.cc
+++ b/src/parallel_ops/allreduce.cc
@@ -131,6 +131,7 @@ void AllReduce::init(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          outputs[0]->machine_view.hash());
+  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -164,6 +165,7 @@ void AllReduce::forward(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          outputs[0]->machine_view.hash());
+  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -212,6 +214,7 @@ void AllReduce::backward(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          inputs[0]->machine_view.hash());
+  // launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad,
                                                     0 /*projection id*/,
                                                     READ_WRITE,
@@ -265,6 +268,7 @@ void AllReduce::init_inference(FFModel const &ff,
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -306,6 +310,7 @@ FutureMap AllReduce::inference(FFModel const &ff,
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_future(bc);
   launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                     0 /*projection id*/,
diff --git a/src/parallel_ops/parallel_identity.cc b/src/parallel_ops/parallel_identity.cc
index 883910ae09..7d68036709 100644
--- a/src/parallel_ops/parallel_identity.cc
+++ b/src/parallel_ops/parallel_identity.cc
@@ -133,6 +133,7 @@ void ParallelIdentity::init(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          outputs[0]->machine_view.hash());
+  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -214,6 +215,7 @@ void ParallelIdentity::backward(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          inputs[0]->machine_view.hash());
+  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad,
                                                     0 /*projection id*/,
                                                     READ_WRITE,
@@ -268,6 +270,7 @@ void ParallelIdentity::init_inference(
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -381,6 +384,7 @@ FutureMap
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
+  launcher.concurrent = true;
   launcher.add_future(bc);
   launcher.add_region_requirement(
       RegionRequirement(batch_inputs[0]->part_grad,
diff --git a/src/runtime/model.cc b/src/runtime/model.cc
index 5213633e73..52f1dd2220 100644
--- a/src/runtime/model.cc
+++ b/src/runtime/model.cc
@@ -6888,6 +6888,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(LORA_LINEAR_INIT_TASK_ID, "LoraLinear Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "LoraLinear Init Task");
@@ -6919,6 +6920,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "LoraLinear PEFT Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "LoraLinear PEFT Backward Task");
@@ -6950,6 +6952,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(FUSEDOP_INIT_TASK_ID, "FusedOp Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Init Task");
@@ -6964,6 +6967,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(FUSEDOP_INF_TASK_ID, "FusedOp Inference");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Inference Task");
@@ -6979,6 +6983,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "FusedOp PEFT Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp PEFT Backward Task");
@@ -6994,6 +6999,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(FUSEDOP_FWD_TASK_ID, "FusedOp Forward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Forward Task");
@@ -7008,6 +7014,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(FUSEDOP_BWD_TASK_ID, "FusedOp Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Backward Task");
@@ -7244,6 +7251,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(ALLREDUCE_INIT_TASK_ID, "AllReduce Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce init Task");
@@ -7258,6 +7266,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(ALLREDUCE_FWD_TASK_ID, "AllReduce Forward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    // AllReduce forward and backward must run concurrently since they
+    // use ncclAllReduce internally
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce Forward Task");
@@ -7272,6 +7283,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    // AllReduce forward and backward must run concurrently since they
+    // use ncclAllReduce internally
+    // registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce Backward Task");
@@ -7287,6 +7301,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "AllReduce Inference");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    // AllReduce forward and backward must run concurrently since they
+    // use ncclAllReduce internally
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce Inference Task");
@@ -7302,6 +7319,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "AllReduce PEFT Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    // AllReduce forward and backward must run concurrently since they
+    // use ncclAllReduce internally
+    // registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce PEFT Backward Task");
@@ -7318,6 +7338,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "ParallelIdentity Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "ParallelIdentity init Task");
@@ -7349,6 +7370,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "ParallelIdentity Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "ParallelIdentity Backward Task");
@@ -7381,6 +7403,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "ParallelIdentity PEFT Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "ParallelIdentity PEFT Backward Task");
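
For reference, the change applies the same two-step pattern throughout: the task
variant is registered with registrar.set_concurrent(), and each IndexLauncher
that launches it sets launcher.concurrent = true so Legion co-schedules all
point tasks of the index launch. That matters because ncclAllReduce blocks until
every rank has entered the collective; if the point tasks ran one at a time, the
launch would deadlock. The sketch below condenses the pattern into one place; it
is a minimal sketch, not FlexFlow code: MY_NCCL_TASK_ID, my_nccl_task, and the
two helper functions are placeholders, while set_concurrent() and the concurrent
launcher flag are used exactly as in the hunks above.

    #include "legion.h"
    using namespace Legion;

    enum TaskIDs {
      MY_NCCL_TASK_ID = 1, // placeholder task id
    };

    // Task body that would perform a NCCL collective (e.g. ncclAllReduce).
    // Every point task of the launch must be running before the call can finish.
    void my_nccl_task(Task const *task,
                      std::vector<PhysicalRegion> const &regions,
                      Context ctx,
                      Runtime *runtime) {
      // ncclAllReduce(...) omitted; all ranks must reach this point together.
    }

    void register_my_nccl_task() {
      TaskVariantRegistrar registrar(MY_NCCL_TASK_ID, "My NCCL Task");
      registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
      registrar.set_leaf();
      // Step 1: mark the variant concurrent so the runtime may gang-schedule it.
      registrar.set_concurrent();
      Runtime::preregister_task_variant<my_nccl_task>(registrar, "My NCCL Task");
    }

    void launch_my_nccl_task(Context ctx, Runtime *runtime, Domain domain) {
      IndexLauncher launcher(MY_NCCL_TASK_ID,
                             domain,
                             TaskArgument(nullptr, 0),
                             ArgumentMap(),
                             Predicate::TRUE_PRED,
                             false /*must*/,
                             0 /*mapper_id*/);
      // Step 2: request a concurrent index launch, as in the hunks above.
      launcher.concurrent = true;
      FutureMap fm = runtime->execute_index_space(ctx, launcher);
      fm.wait_all_results();
    }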