[AllReduce] make AllReduce tasks concurrent in FlexFlow (#1517)
* minor bug fix

* make AllReduce tasks concurrent

* set concurrent=true for remaining operators

---------

Co-authored-by: Gabriele Oliaro <[email protected]>
jiazhihao and goliaro authored Oct 5, 2024
1 parent c78cf04 commit ca3dabf
Showing 5 changed files with 40 additions and 0 deletions.
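
The change applies one pattern across all five files: every index launch whose task eventually issues an NCCL collective is marked concurrent on the launcher side, and the matching task variant is declared concurrent when it is registered. A minimal sketch of the two halves, based on the hunks below (the region requirements the real launchers add are omitted here):

// Registration side (src/runtime/model.cc): declare the task variant concurrent.
TaskVariantRegistrar registrar(ALLREDUCE_FWD_TASK_ID, "AllReduce Forward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent(); // all point tasks of one index launch may run at the same time

// Launch side (e.g. src/parallel_ops/allreduce.cc): request a concurrent index launch.
IndexLauncher launcher(ALLREDUCE_FWD_TASK_ID,
                       outputs[0]->parallel_is, // launch domain: one point per device
                       TaskArgument(nullptr, 0),
                       argmap,
                       Predicate::TRUE_PRED,
                       false /*must*/,
                       0 /*mapper_id*/,
                       outputs[0]->machine_view.hash());
launcher.concurrent = true;
FutureMap fm = runtime->execute_index_space(ctx, launcher);
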
6 changes: 6 additions & 0 deletions src/ops/fused.cc
@@ -476,6 +476,7 @@ void FusedOp::init(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
launcher.concurrent = true;
FutureMap fm = runtime->execute_index_space(ctx, launcher);
fm.wait_all_results();
switch (domain.get_dim()) {
@@ -570,6 +571,7 @@ void FusedOp::init_inference(FFModel const &ff,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
FutureMap fm = runtime->execute_index_space(ctx, launcher);
fm.wait_all_results();
switch (domain.get_dim()) {
@@ -604,6 +606,7 @@ void FusedOp::forward(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
launcher.concurrent = true;
int offset = 0;
for (int i = 0; i < numInputs; i++) {
assert(inputs[i]->part != LogicalPartition::NO_PART);
@@ -659,6 +662,7 @@ FutureMap FusedOp::inference(FFModel const &ff,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_future(bc);
int offset = 0;
for (int i = 0; i < numInputs; i++) {
@@ -735,6 +739,7 @@ FutureMap FusedOp::peft_bwd(FFModel const &ff,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_future(bc);
int offset = 0;
for (int i = 0; i < numInputs; i++) {
@@ -787,6 +792,7 @@ void FusedOp::backward(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
launcher.concurrent = true;
int idx = 0;
for (int i = 0; i < numInputs; i++) {
launcher.add_region_requirement(RegionRequirement(inputs[i]->part,
2 changes: 2 additions & 0 deletions src/ops/lora_linear.cc
@@ -296,6 +296,7 @@ void LoraLinear::init_inference(
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
@@ -795,6 +796,7 @@ FutureMap LoraLinear::peft_bwd(FFModel const &ff,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_future(bc);
launcher.add_region_requirement(
RegionRequirement(batch_inputs[0]->part_grad,
5 changes: 5 additions & 0 deletions src/parallel_ops/allreduce.cc
@@ -131,6 +131,7 @@ void AllReduce::init(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
@@ -164,6 +165,7 @@ void AllReduce::forward(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
@@ -212,6 +214,7 @@ void AllReduce::backward(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
inputs[0]->machine_view.hash());
// launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
@@ -265,6 +268,7 @@ void AllReduce::init_inference(FFModel const &ff,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
@@ -306,6 +310,7 @@ FutureMap AllReduce::inference(FFModel const &ff,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_future(bc);
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
4 changes: 4 additions & 0 deletions src/parallel_ops/parallel_identity.cc
@@ -133,6 +133,7 @@ void ParallelIdentity::init(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
@@ -214,6 +215,7 @@ void ParallelIdentity::backward(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
inputs[0]->machine_view.hash());
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad,
0 /*projection id*/,
READ_WRITE,
@@ -268,6 +270,7 @@ void ParallelIdentity::init_inference(
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
@@ -381,6 +384,7 @@ FutureMap
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_future(bc);
launcher.add_region_requirement(
RegionRequirement(batch_inputs[0]->part_grad,
23 changes: 23 additions & 0 deletions src/runtime/model.cc
@@ -6888,6 +6888,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(LORA_LINEAR_INIT_TASK_ID, "LoraLinear Init");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<OpMeta *, LoraLinear::init_task>(
registrar, "LoraLinear Init Task");
@@ -6919,6 +6920,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"LoraLinear PEFT Backward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<LoraLinear::peft_bwd_task>(
registrar, "LoraLinear PEFT Backward Task");
@@ -6950,6 +6952,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(FUSEDOP_INIT_TASK_ID, "FusedOp Init");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<OpMeta *, FusedOp::init_task>(
registrar, "FusedOp Init Task");
@@ -6964,6 +6967,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(FUSEDOP_INF_TASK_ID, "FusedOp Inference");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<FusedOp::inference_task>(
registrar, "FusedOp Inference Task");
@@ -6979,6 +6983,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"FusedOp PEFT Backward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<FusedOp::peft_bwd_task>(
registrar, "FusedOp PEFT Backward Task");
@@ -6994,6 +6999,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(FUSEDOP_FWD_TASK_ID, "FusedOp Forward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<FusedOp::forward_task>(
registrar, "FusedOp Forward Task");
@@ -7008,6 +7014,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(FUSEDOP_BWD_TASK_ID, "FusedOp Backward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<FusedOp::backward_task>(
registrar, "FusedOp Backward Task");
@@ -7244,6 +7251,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(ALLREDUCE_INIT_TASK_ID, "AllReduce Init");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<OpMeta *, AllReduce::init_task>(
registrar, "AllReduce init Task");
@@ -7258,6 +7266,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(ALLREDUCE_FWD_TASK_ID, "AllReduce Forward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
// AllReduce forward and backward must run concurrently since they
// use ncclAllReduce internally
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<AllReduce::forward_task>(
registrar, "AllReduce Forward Task");
@@ -7272,6 +7283,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
// AllReduce forward and backward must run concurrently since they
// use ncclAllReduce internally
// registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<AllReduce::backward_task>(
registrar, "AllReduce Backward Task");
@@ -7287,6 +7301,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"AllReduce Inference");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
// AllReduce forward and backward must run concurrently since they
// use ncclAllReduce internally
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<AllReduce::inference_task>(
registrar, "AllReduce Inference Task");
@@ -7302,6 +7319,9 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"AllReduce PEFT Backward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
// AllReduce forward and backward must run concurrently since they
// use ncclAllReduce internally
// registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<AllReduce::peft_bwd_task>(
registrar, "AllReduce PEFT Backward Task");
@@ -7318,6 +7338,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"ParallelIdentity Init");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<OpMeta *, ParallelIdentity::init_task>(
registrar, "ParallelIdentity init Task");
@@ -7349,6 +7370,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"ParallelIdentity Backward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<ParallelIdentity::backward_task>(
registrar, "ParallelIdentity Backward Task");
@@ -7381,6 +7403,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
"ParallelIdentity PEFT Backward");
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
if (pre_register) {
Runtime::preregister_task_variant<ParallelIdentity::peft_bwd_task>(
registrar, "ParallelIdentity PEFT Backward Task");
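
The comments added to the AllReduce registrations above give the rationale: the forward and inference tasks call ncclAllReduce, a collective that only completes once every rank in the communicator has joined, so all point tasks of the index launch must be able to run at the same time rather than be serialized by the scheduler (the backward and PEFT-backward variants leave the flag commented out in this commit). A rough sketch of the kind of call involved, not FlexFlow's actual kernel code; the wrapper and buffer names are illustrative:

#include <cuda_runtime.h>
#include <nccl.h>

// Each point task (one per GPU) issues one call like this. The collective
// cannot complete until every rank in `comm` has made the matching call,
// which is why the launch is marked concurrent above.
void all_reduce_sum(float const *send_buf, float *recv_buf, size_t count,
                    ncclComm_t comm, cudaStream_t stream) {
  ncclResult_t rc =
      ncclAllReduce(send_buf, recv_buf, count, ncclFloat, ncclSum, comm, stream);
  (void)rc; // error handling elided in this sketch
}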
