Skip to content

Commit

Permalink
Initial sharding support
Browse files Browse the repository at this point in the history
-shards [N] -- do a simple sharding based on topological ordering of
DAG.
  • Loading branch information
Weiming Zhao committed Jul 29, 2021
1 parent 7505ead commit 916c8ef
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 1 deletion.
4 changes: 3 additions & 1 deletion driver/driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ static llvm::cl::opt<int> IpuNum(
"ipu-num",
llvm::cl::desc("Num of ipus, should consistent with subgraph num"),
llvm::cl::init(1), llvm::cl::cat(HaloOptCat));

// Command-line option "shards": requested number of shards. Defaults to -1;
// the value is copied into cg_opts.num_shards in main().
static llvm::cl::opt<int> Shards("shards", llvm::cl::desc("Num of shards"),
llvm::cl::init(-1), llvm::cl::cat(HaloOptCat));
// Command-line option "batches-per-step": number of batches for each step.
// Defaults to 1; the value is copied into cg_opts.batches_per_step in main().
static llvm::cl::opt<int> BatchesPerStep(
"batches-per-step", llvm::cl::desc("Specify batches num for each step"),
llvm::cl::init(1), llvm::cl::cat(HaloOptCat));
Expand Down Expand Up @@ -358,6 +359,7 @@ int main(int argc, char** argv) {
cg_opts.enable_ipu_device = EnableIpuDevice;
cg_opts.use_ipu_model = UseIpuModel;
cg_opts.ipu_num = IpuNum;
cg_opts.num_shards = Shards;
cg_opts.batches_per_step = BatchesPerStep;
cg_opts.api = Api;
cg_opts.disable_broadcasting = DisableBroadcasting;
Expand Down
1 change: 1 addition & 0 deletions include/halo/halo.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct CXXCodeGenOpts {
bool emit_shared_lib = false;
const char* linked_odla_lib = nullptr;
bool save_temps = false;
int num_shards = -1;
};

#define HALO_MODEL_INFO_MAX_OUTPUT_NR 64
Expand Down
1 change: 1 addition & 0 deletions include/halo/lib/pass/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class HL_API_EXPORT PassManager final {

Pass* AddRISCVLLVMIRCodeGenPass(ConstantDataStorage constant_data_storage,
const std::string& rt_lib_name);
/// Adds a Sharding pass constructed with `num_shards` and returns it.
Pass* AddShardingPass(int num_shards);
Pass* AddSplittingPass();
Pass* AddTFExtensionLegalizerPass();
Pass* AddTFLiteExtensionLegalizerPass();
Expand Down
38 changes: 38 additions & 0 deletions include/halo/lib/transforms/sharding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- sharding.h ---------------------------------------------------------===//
//
// Copyright (C) 2019-2020 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef HALO_LIB_TRANSFORMS_SHARDING_H_
#define HALO_LIB_TRANSFORMS_SHARDING_H_

#include "halo/lib/pass/pass.h"

namespace halo {

/// This pass computes a sharding scheme for a function, i.e. an assignment of
/// its arguments and instructions to a requested number of shards.
class Sharding final : public FunctionPass {
 public:
  /// Creates the pass.
  /// \param shards Requested number of shards. A value <= 0 makes
  ///               RunOnFunction() a no-op.
  // `explicit` keeps an int from being silently converted into a pass object.
  explicit Sharding(int shards) : FunctionPass("Sharding"), shards_(shards) {}

  /// Runs the pass on `func`.
  /// \return true if the function was modified.
  bool RunOnFunction(Function* func) override;

 private:
  /// Requested number of shards, as passed to the constructor.
  int shards_;
};

} // end namespace halo.

#endif // HALO_LIB_TRANSFORMS_SHARDING_H_
1 change: 1 addition & 0 deletions include/halo/utils/passes_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ static void PopulateOptPasses(PassManager* pm, const std::string& target,
if (opts.enable_type_cast) {
pm->AddTypeCastPass();
}
pm->AddShardingPass(opts.num_shards);
}
} // namespace halo

Expand Down
5 changes: 5 additions & 0 deletions lib/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "halo/lib/transforms/onnxextension_legalizer.h"
#include "halo/lib/transforms/output_rewriter.h"
#include "halo/lib/transforms/reorder_channel.h"
#include "halo/lib/transforms/sharding.h"
#include "halo/lib/transforms/splitting.h"
#include "halo/lib/transforms/tfextension_legalizer.h"
#include "halo/lib/transforms/tfliteextension_legalizer.h"
Expand Down Expand Up @@ -382,6 +383,10 @@ Pass* PassManager::AddRISCVLLVMIRCodeGenPass(
return AddPass<RISCVLLVMIRCodeGen>(constant_data_storage);
}

// Creates a Sharding pass for `num_shards` shards through AddPass and hands
// the resulting pass back to the caller.
Pass* PassManager::AddShardingPass(int num_shards) {
  Pass* sharding_pass = AddPass<Sharding>(num_shards);
  return sharding_pass;
}

// Creates a Splitting pass through AddPass and hands the resulting pass back
// to the caller.
Pass* PassManager::AddSplittingPass() {
  Pass* splitting_pass = AddPass<Splitting>();
  return splitting_pass;
}

Pass* PassManager::AddTFExtensionLegalizerPass() {
Expand Down
1 change: 1 addition & 0 deletions lib/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(SRCS
onnxextension_legalizer.cc
output_rewriter.cc
reorder_channel.cc
sharding.cc
splitting.cc
tfextension_legalizer.cc
tfliteextension_legalizer.cc
Expand Down
133 changes: 133 additions & 0 deletions lib/transforms/sharding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//===- sharding.cc --------------------------------------------------------===//
//
// Copyright (C) 2019-2021 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "halo/lib/transforms/sharding.h"

#include <iostream>
#include <limits>
#include <unordered_map>
#include <unordered_set>

#include "halo/lib/ir/ir_builder.h"

namespace halo {

// Returns the number of arguments of `func` plus the number of instructions,
// over all of its basic blocks, that are not a Constant.
static unsigned CountNonConstOps(const Function& func) {
  unsigned count = func.Args().size();
  for (const auto& block : func) {
    for (const auto& inst : *block) {
      if (!IsA<Constant>(*inst)) {
        ++count;
      }
    }
  }
  return count;
}

// Simple sharding scheme: try to get equal shards of op.
//
// Visits the function level by level, starting from its arguments and the
// non-constant instructions whose operands are all constants, and assigns
// every visited node to the current shard. A constant operand gets the shard
// of the first visited node that uses it. Once the current shard holds at
// least ceil(num_ops / num_shards) non-constant nodes, the next level starts
// a new shard.
//
// func:       function whose arguments and instructions are assigned.
// num_shards: requested number of shards. Zero means "never cut": everything
//             is placed in shard 0.
// Returns a map from IR object to shard index. Every instruction of `func`
// has an entry; instructions the traversal did not reach are reported on
// stderr and placed in the last shard.
static std::unordered_map<IRObject*, int> GetSimpleSharding(
    const Function& func, unsigned num_shards) {
  std::unordered_map<IRObject*, int> shardings;
  const unsigned num_ops = CountNonConstOps(func);
  const bool never_cut = (num_shards == 0);
  // Number of shard indices that may be handed out. "Never cut" still needs
  // shard 0, so the bound below and the fallback index stay well defined.
  const unsigned shard_count = never_cut ? 1 : num_shards;
  const unsigned threshold = never_cut
                                 ? std::numeric_limits<unsigned>::max()
                                 : (num_ops + num_shards - 1) / num_shards;
  unsigned curr_shard = 0;
  unsigned allocated = 0;
  // A node is available once each of its non-constant operands has a shard.
  auto is_avail = [&shardings](const IRObject* n) {
    for (auto& op : n->GetOperands()) {
      if (!IsA<Constant>(op) && shardings.count(op.GetOwner()) == 0) {
        return false;
      }
    }
    return true;
  };

  // BFS visit. Seed the first level with the arguments and with every
  // non-constant instruction that is available right away.
  std::unordered_set<IRObject*> workset;
  for (auto& arg : func.Args()) {
    workset.insert(arg.get());
  }
  for (auto& bb : func) {
    for (auto& ir : *bb) {
      if (!IsA<Constant>(ir.get()) && is_avail(ir.get())) {
        workset.insert(ir.get());
      }
    }
  }

  while (!workset.empty()) {
    std::unordered_set<IRObject*> next;
    HLCHECK(curr_shard < shard_count);
    // Assign the whole level (and the constants it pulls in) to curr_shard.
    for (auto& node : workset) {
      HLCHECK(shardings.count(node) == 0);
      shardings[node] = static_cast<int>(curr_shard);
      if (!IsA<Constant>(node)) {
        ++allocated;
      }
      for (auto& op : node->GetOperands()) {
        if (shardings.count(op.GetOwner()) == 0 && IsA<Constant>(op)) {
          shardings[op.GetOwner()] = static_cast<int>(curr_shard);
        }
      }
    }
    // check successors: users that just became available form the next level.
    for (auto& node : workset) {
      for (auto& uses : node->GetResultsUses()) {
        for (auto& user : uses) {
          auto n = user.GetUse();
          if (is_avail(n) && shardings.count(n) == 0) {
            next.insert(n);
          }
        }
      }
    }
    // The current shard is full: the next level opens a new shard.
    if (allocated >= threshold) {
      ++curr_shard;
      allocated = 0;
      for (auto& n : next) {
        std::cout << " ---- cut:" << n->GetName() << std::endl;
      }
    }
    workset.swap(next);
  }

  // Report any instruction the traversal did not reach and put it in the last
  // shard, so that the returned map covers every instruction.
  for (auto& bb : func) {
    for (auto& ir : *bb) {
      auto inst = ir.get();
      if (shardings.count(inst) == 0) {
        std::cerr << "[sharding] unallocated: " << inst->GetName() << "\n";
        shardings[inst] = static_cast<int>(shard_count - 1);
      }
    }
  }
  return shardings;
}

// Prints one "<shard>:<name>:" line to stdout for every entry of `shardings`.
static void ApplySharding(const std::unordered_map<IRObject*, int>& shardings) {
  for (auto it = shardings.begin(); it != shardings.end(); ++it) {
    IRObject* const object = it->first;
    const int shard = it->second;
    std::cout << shard << ":" << object->GetName() << ":" << std::endl;
  }
}

// Computes the simple sharding scheme for `func` and passes it to
// ApplySharding. Does nothing when the configured shard count is not
// positive. Always returns false.
bool Sharding::RunOnFunction(Function* func) {
  const bool sharding_requested = shards_ > 0;
  if (sharding_requested) {
    const auto scheme = GetSimpleSharding(*func, shards_);
    ApplySharding(scheme);
  }
  return false;
}

} // end namespace halo

0 comments on commit 916c8ef

Please sign in to comment.